{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 的士调度 Taxi-v2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import numpy as np\n",
    "np.random.seed(0)\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import gym"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 环境使用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "观察空间 = Discrete(500)\n",
      "动作空间 = Discrete(6)\n",
      "状态数量 = 500\n",
      "动作数量 = 6\n"
     ]
    }
   ],
   "source": [
    "# Build the Taxi environment and inspect its discrete observation/action spaces.\n",
    "# NOTE(review): 'Taxi-v2' is only registered in older gym releases (newer ones ship\n",
    "# 'Taxi-v3') -- confirm the gym version this notebook is pinned to.\n",
    "env = gym.make('Taxi-v2')\n",
    "env.seed(0)  # fix the environment's random seed\n",
    "print('观察空间 = {}'.format(env.observation_space))\n",
    "print('动作空间 = {}'.format(env.action_space))\n",
    "print('状态数量 = {}'.format(env.observation_space.n))\n",
    "print('动作数量 = {}'.format(env.action_space.n))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 1 1 2\n",
      "的士位置 = (0, 1)\n",
      "乘客位置 = (0, 4)\n",
      "目标位置 = (4, 0)\n",
      "+---------+\n",
      "|R:\u001b[43m \u001b[0m| : :\u001b[34;1mG\u001b[0m|\n",
      "| : : : : |\n",
      "| : : : : |\n",
      "| | : | : |\n",
      "|\u001b[35mY\u001b[0m| : |B: |\n",
      "+---------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Start an episode and split the integer state into its components.\n",
    "state = env.reset()\n",
    "taxirow, taxicol, passloc, destidx = env.unwrapped.decode(state)\n",
    "print(taxirow, taxicol, passloc, destidx)\n",
    "print('的士位置 = {}'.format((taxirow, taxicol)))\n",
    "# NOTE(review): passloc and destidx are used as indices into env.unwrapped.locs;\n",
    "# a passloc meaning 'passenger already in the taxi' would presumably have no locs\n",
    "# entry -- fine right after reset(), but confirm before reusing this mid-episode.\n",
    "print('乘客位置 = {}'.format(env.unwrapped.locs[passloc]))\n",
    "print('目标位置 = {}'.format(env.unwrapped.locs[destidx]))\n",
    "env.render()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(126, -1, False, {'prob': 1.0})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "env.step(0)  # take action 0; displays (next_state, reward, done, info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+---------+\n",
      "|R: | : :\u001b[34;1mG\u001b[0m|\n",
      "| :\u001b[43m \u001b[0m: : : |\n",
      "| : : : : |\n",
      "| | : | : |\n",
      "|\u001b[35mY\u001b[0m| : |B: |\n",
      "+---------+\n",
      "  (South)\n"
     ]
    }
   ],
   "source": [
    "env.render()  # show the grid after the step; the last action is printed underneath"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SARSA 算法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SARSAAgent:\n",
    "    \"\"\"Tabular SARSA agent with an epsilon-greedy behaviour policy.\n",
    "\n",
    "    The action-value table ``q`` has one row per state and one column per\n",
    "    action, and starts at zero.\n",
    "\n",
    "    Args:\n",
    "        env: environment exposing discrete ``observation_space.n`` and\n",
    "            ``action_space.n``; only used to size the table.\n",
    "        gamma: discount factor applied to the bootstrapped value.\n",
    "        learning_rate: step size of the TD update.\n",
    "        epsilon: probability of picking a uniformly random action.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, env, gamma=0.9, learning_rate=0.1, epsilon=.01):\n",
    "        n_states = env.observation_space.n\n",
    "        n_actions = env.action_space.n\n",
    "        self.gamma = gamma\n",
    "        self.learning_rate = learning_rate\n",
    "        self.epsilon = epsilon\n",
    "        self.action_n = n_actions\n",
    "        self.q = np.zeros((n_states, n_actions))\n",
    "\n",
    "    def decide(self, state):\n",
    "        \"\"\"Return an epsilon-greedy action for ``state``.\n",
    "\n",
    "        Always draws one uniform sample; an integer sample is drawn only when\n",
    "        exploring. Greedy ties resolve to the lowest action index (argmax).\n",
    "        \"\"\"\n",
    "        greedy = np.random.uniform() > self.epsilon\n",
    "        if not greedy:\n",
    "            return np.random.randint(self.action_n)\n",
    "        return self.q[state].argmax()\n",
    "\n",
    "    def learn(self, state, action, reward, next_state, done, next_action):\n",
    "        \"\"\"Move Q(state, action) toward the one-step SARSA target.\n",
    "\n",
    "        target = reward + gamma * Q(next_state, next_action) * (1 - done),\n",
    "        so the bootstrapped term vanishes once the episode is done.\n",
    "        \"\"\"\n",
    "        not_done = 1. - done\n",
    "        target = reward + self.gamma * self.q[next_state, next_action] * not_done\n",
    "        self.q[state, action] += self.learning_rate * (target - self.q[state, action])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def play_sarsa(env, agent, train=False, render=False):\n",
    "    \"\"\"Play one episode with ``agent`` and return the sum of its rewards.\n",
    "\n",
    "    Args:\n",
    "        env: gym-style environment providing ``reset``/``step``/``render``.\n",
    "        agent: object with ``decide(state)`` and, when ``train`` is true,\n",
    "            ``learn(state, action, reward, next_state, done, next_action)``.\n",
    "        train: call ``agent.learn`` after every transition.\n",
    "        render: call ``env.render()`` before every step.\n",
    "\n",
    "    Returns:\n",
    "        Undiscounted total reward of the episode.\n",
    "    \"\"\"\n",
    "    total_reward = 0\n",
    "    state = env.reset()\n",
    "    action = agent.decide(state)\n",
    "    done = False\n",
    "    while not done:\n",
    "        if render:\n",
    "            env.render()\n",
    "        next_state, reward, done, _ = env.step(action)\n",
    "        total_reward += reward\n",
    "        # Sampled even after the final step: the SARSA update takes it as an\n",
    "        # argument, and SARSAAgent.learn multiplies its value by (1 - done).\n",
    "        next_action = agent.decide(next_state)\n",
    "        if train:\n",
    "            agent.learn(state, action, reward, next_state, done, next_action)\n",
    "        state, action = next_state, next_action\n",
    "    return total_reward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "平均回合奖励 = 671 / 100 = 6.71\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAD8CAYAAAB6paOMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xl8VfWd//HXJwsJ+76HQNCobEIxgKC4oiBW6TJVbKvYWrEWuk1/08rYqbX+cJx2Wqd2rB3aoTN2o1Rty09pKVjXtsiisskWFksEDJshiKz5/P64J/EmuTc3yU1yk3vez8fjPjj3c8495/u93JzPOd/vOedr7o6IiIRXRqoLICIiqaVEICISckoEIiIhp0QgIhJySgQiIiGnRCAiEnJKBCIiIadEICISckoEIiIhl5XqAtRHr169fMiQIakuhohIm7J27dqD7t470XJtIhEMGTKENWvWpLoYIiJtipm9WZ/l1DQkIhJySgQiIiGnRCAiEnJKBCIiIadEICISckoEIiIhp0QgIhJySgR1ePPQu9XeV1Q4fz90POHnTp+tYM/h95fbffBdXtp+AICNb5Wxbs87lJafYOHLuzhztoLD757iibUllB49wYnTZ9lX9l6tde4vO8FTr5Zw5N1TdW77QPlJSstPUHr0RLX4c1tK2bL/KNvfLqe0/ASb9x3lneOnOHTsJGXvneZsULddB98levjSV3Ye4t2TZwB4++gJjp86U2ubp85UUHIkUt+t+8tZufNQrWWi61V2/DRb9h9l2ab9AGwoKePnK9+ktPxErc/FUlHhLF69h1NnKmrNOxvj/6jsvdMcOnYSgIPHTlJ+4nTM9b60/QC7Dkb+z5du2MeGkjKOnzpT7f/S3at+FyfPnOXnK9+s9jtZ++YRNu0tq7be3VHf6e6D77LzwDHeO3WWzfuO8srOQ6zadbjWck+v38uhYyerrfv5raX86IUdtf5vo7k7T6wt4b1TZ4HI/1/ldF12B/WO9f0BlAb/9zV/n5XL/3XHQZ7bWkpFRaT8peUnWPDijqr1xlJ2/DSHE/yeIfI9L16zh5rD6pYePcGxk2fYX3aiXnWM5UD5ScqOn2bx6j08u/nthMvvOXycM2dr/+7iOXjsJL977S3mPbWB/WUnKC4tZ+2bR2p9L3/dcZDi0sjvYn/ZCZ5Zv4/X/n6EV3YeYtvb5Q2uV0NZqsYsNrNpwPeBTOAn7v5QvGWLioq8pW8oe35rKbf/dDX/90Mj+eTFg9lx4BhXf/eFqvkLbr2Id947TcmR93jk2e1kZhg3jh7AM+v3cSr4oVx1QR/+vKW0RctdU1739pQcqZ1YEunVKYeDwc4zlp4d23GoHn/EDdE+O5MRA7qw5s0jTbK+QT3aM7hHR14uPtgk60sXk87pyV931E7WDTG0d0d2Hoi/k28tPjo2jydfLUl1MZK2+6HrG/U5M1vr7kUJl0tFIjCzTGAbcA1QAqwGbnH3N2Itn4pE8OVfv85vX3sLgC65WRw9Uf1IuHNOFuUnax8di4g0teZOBKlqGhoPFLv7Tnc/BSwCZqSoLDFFn2rWTAKAkoCIpI1UJYKBwJ6o9yVBLKUqKrxWO6RIU5p0Ts+Y8V6dclq4JNW3PWpg15Rtvyl9bdoFfPOG4akuBpecG/v/OZZvzRhBtw7ZzViaxFL10DmLEau2Bzaz2cBsgPz8/GYv0NtHTzDhwWd58MOj+PiE5t9eOhmd15V1JWWJF6zD4rsmctN//S3mvMI+nRhX0INfvvL3pLbRWF+bdgH/9sctVe/vu2E4G94q46lXI02HN44ewAMzRvLGvqMU9OpI7845nPPPSwG4c3IBX7i6kFHf/BMAv/jMBB55tphh/Tsz+2drAVh4exF53Ttw7cMvVtvu1BF9GdCtPT/9y24uP683HyvKY+4vX6tXmaObEobc80zcemVmwOzLzmHb2+VV2/+Pm8fwpV+/
XrXczgenU3LkPS77znNMHdGXK87vw4NLN1MenCl3yc3i0sJeLN2wny9NKeS2iUPYsu8oFw3pTk5WZtX2byrKY/GaEuZceQ6b95Xz5y2lXNCvM1v2V+8M3T7/Ogrv/QMAL/zTFfx5SylPr9/HyAFduH/GSErLTzB+/rMM7d2RP3xxMqt2HebW/14FwN1XnAPA7ZcUVK1v0aq/c89TG5gyrA8PfmQU1z/yMj+9fRwf/MHLAJhBzeO/zd+aRvt2mUDkQoeHl2/j81cX0imn+i5z5c5D7H3nPf5x8ToA+nfN5fdzLqFPl1zcnYJ5SxkxoAub9h6t1sRcs6nntolDeOz5HfzbH7ew9utT2LK/nKdefYsnXy1hzKBuMf//mlKq+ggmAt9096nB+3kA7v6vsZZvzj6CE6fPkpVhrHnzCDMXrGR8QQ8W3zWRu362hmWbEl9FIPDFqwv5/rPbG/y5n90xngkFPTGD7MyMuDusOycXcO/1kaO8WMs89JFR3PPUBiCy0zp1toIL/uWPAHzm0gImDO3Ja38/wg+f3wHAwzePpk/nXAZ0a88nfrySC/p3oUtuFr97fW/VOpd+YTLTH3kJiPzRVm43+g949e7DjBrYldzszFpl2n3wXZ7dUsodl0Z2SMWlx+jQLpMB3doDVO0k7rpsKPOmDwMiV/j8YeN+tr1dzr9+ZBSDe3YEIr/RnKwMzKzq93qmwpm5YCWzLxvKO8dPM7R3R06dqcAMjp86y9QR/arKUhmv3LkCPHn3JC4a3L3q/eZ9R7nu+y/RKSeLjfdPrfZdV9Z5ze7DDOvfhY7BznDZpv386IUdPHX3JH6ztoSvPrGe788cw4wx1U/uN+0to1+XXDrmZPHoc8XMufJczlY4W/aXMzqvK0++WsLXntzA/TeOoFenHK6/sD8nTp+lXWYGGRmxjhnhJy/tZOqIfgzq0QGAx57fwfa3y/nezWNiLl9T5fdfNLh7tYsTPjo2j4c+OorszIY1lizdsI/uHdoxscYZ37o97zC4Zwf2vnOC3p1zGDd/BRfmdWXJ3EtjlunkmYpqv6cj756iW4dszGJ/D4nUt48gVWcEq4FCMysA3gJmAh9vyQLMePQvrNvzDgAX5nXln4M/xkp/S/KqijD4/ZxLOHjsJFec36fORPD/5l7KDf/5cq345MLqj0k/p3dHdsS4EuWuy8+pmn7xn67ksu88V/X+K9ecx8zx+WRkGC9sO0BGhpGbkVmVnDrkZHHN8L5cM7xvVSL48Afyqj7/13lXV01XJoIdD04nM84OKNq4IT3izhvSq2NVEgA4t0+navPNrNZR4YShPZkwtHaTQvSOoXI6KxN+N+eShGUEaJdVfac2cWhPxuZXP8oc2rsjYwZ14+vXv/938F+3XkTPju2q3hfVqO/UEf2qEs7HLsojv0cHJhTU/k5GDHi/2ekr155fNV2ZiG4el8+HP5BXrZyxkmu0z0weWu195ZlAfZkZT3/+UvJ7duCVnYe58/HIgWbX9tkNTgIA00f1jxkfHRzNd+sQ+R6L518Xd6duZrXq3T3q+29OKekjcPczwFxgGbAZWOzum5p7u8Wl5WwNTkMrkwDA+pIy7gpO0yvF6iBu66aP6pd4oRh+fseEmPHRg7px9bC+CXeavTvXbv8eH2Mn+tTd7+/Yfn7HBHKCHUP0H0fNv6G5V50LwE1Fg3j042Or4h1zIp/p2r7hba/x6lOP3NCqdWiXyfghPfjV7Itr7YxysjL53ZxLqu3sp47oV2vnH4+ZcfHQno0+cq2ZrFrCyIFd6ZKbzTXD+1YlwEYWv96yMjPqdZDR0lI2MI27LwWWtuQ2p3wv0gYa61Kssvdi32SUTh6YMZKlG/bHnHfZeb15cduBmPMyMuCW8YP41ar3+/fnXXdBvbZZNLg7FTGaHwt6dawV69ohm998diL/85fd1TpVo/9uBnRrz5RhfVhXUsbCWePi7nhun1RAhhm3TRxcr3Im8vLXrqR9gqPU1u6Nb01L
dRFarUvO7QVQrUktTNrECGUtqfXl6qbTs1MOGQYVMbqFbirKq5YIfnDLB/jxSztZX1JGhhlfnXpBtUQQz81Fg/j1mveXe+LuSdXuGB5f0INVuw7HPfIaN6RHrSYXi/pfycwwfjJrXMJytMvKqNV88D+fGkf/ru3jfmbpFyZz6N3YN9Hlde+QcJvSdg3r36XR1+qnAyWCGGo+WiKd5PfowO4YjxC4flR/5vL+FSk3jB7A9tJjrC8pY0DX9lUdhPH8z6fGMbhnR37y0s5a8/p0zmVsfjduHD2AnOzMqkcqtLQrzu9T5/zhA7q0UElEWhclghjS+ZEEGcGheLcO2bxz/DTzrruA02crqjWxvPov1wCRq4E+dlFe1ZUZn5iQzy/iXMKZaCf71Oci7f+/WhX5fEMuVmvudtt47p0+jMK+nRIvKNLGKRHEYGnYQPTEZycC7+9UH755DC9uO8Cdk4fWukSvR3ClQmaGVSUBqHGjRwtol5XByTMVDUoaTenOy4YmXkgkDSgR1JCqo8+m8Ld5V/HzlW9W3awTrfLqj8ozggFd23PfDSMatP7oHXKy++b6fM9PfHYSf9y4v+rGHhFpHkoENazcebhej8ZtjbrkZvNPUyNX8yxZt5cv/Kr2XaiViSDWlTyJteyh+fn9OnN+v84tuk2RMNJ4BDFse/tYqovQKNFH2WPyYt+WXrlMss0tbfjESURqUCJIIxlRmcDjHL2PDhJE59yGnwzqeXwi6UmJIE0N6Na+2rNkKt0/YwS//dykap3AlRLdMKVEIJKelAjSVHZmBk/ePQmA2ycNqYrnZmfygfzaCQJg6Rcn872bRtdr/ZV3YtblEzGe4toh6Pjt0ohHP4hI81BncRqJ9QyThtwtWdCrY8xHP1Tq1jGy8/7BLR9gZILn1z/woZHcenHtxzvccOEADh07pUd9i7QiSgRppDFPTWyIL085j7zuHbg+zpMW6yMjw/h01FM5RST11DTUBn1/Zv2eud7UcrMzufXiwXGfES8ibZMSQRs0McYz60VEGkuJoC1qxQfklQOrDO+vG8FE2opQ9hGcOVuR6iKkrRtHD2DSOT1TOhi7iDRMUmcEZvYxM9tkZhVmVlRj3jwzKzazrWY2NSo+LYgVm9k9yWy/sb63fFsqNttkYj0UrzU12ysJiLQtyTYNbQQ+ArwYHTSz4UTGIR4BTAN+aGaZZpYJPApcBwwHbgmWbVEb3ipr6U02qVjD+m1+QKNPiUjjJNU05O6bgVjDBc4AFrn7SWCXmRUD44N5xe6+M/jcomDZN5IpR0M1dlzV1uAr15xXNQ5v1/bZVUNs5mTpCZ0i0jjN1Vk8EIge17AkiMWL12Jms81sjZmtOXAg9li66ejxT4+vc/6Y/Mizglb84+X8+SuXt0SRRCTNJUwEZrbCzDbGeM2o62MxYl5HvHbQfYG7F7l7Ue/evRMVs0Fa8/nAZefFruuEgsh4ApVH/uf26URPtcWLSBNI2DTk7lMasd4SYFDU+zxgbzAdL95i9hypPWZva/fjWUX8/vW9jBsS+zlBIiKN1VxNQ0uAmWaWY2YFQCGwClgNFJpZgZm1I9KhvKSZyhDXzgOtc3D6r18/LO68LrnZ3Hrx4DbdvyEirVNSncVm9mHgB0Bv4Bkze93dp7r7JjNbTKQT+Awwx93PBp+ZCywDMoGF7r4pqRqkkc9M1hi5ItLykr1q6LfAb+PMmw/MjxFfCixNZrthUjS4O/27tU91MUQkjYXyzuK25IlgTAERkeaiZw2JiIScEoGISMgpEYiIhJwSgYhIyIWis9jdeXjF9lQXo04DuuamuggiElKhSAS7Dx3nkWdbbyIY1KM9L331qlQXQ0RCKhRNQxUe83FGrcYV5/VJdRFEJMRCkQhas8wM4xs3tPiQDCIiVZQIUmjKsL7seHA62Zn6bxCR1NEeKIUuObdnqosgIhKOzuLW6vZJQ5Jex4JbL6L8xJnkCyMioaVEkALTR/XjjkuHNskjpa8d
0a8JSiQiYaamoRTo2yWXiwZrgBkRaR2UCEREQi6pRGBm3zGzLWa23sx+a2bdoubNM7NiM9tqZlOj4tOCWLGZ3ZPM9utdzpbYSAO08tsaRCRkkj0jWA6MdPcLgW3APAAzG05kGMoRwDTgh2aWaWaZwKPAdcBw4JZgWRERSZGkEoG7/8ndKy9ZWUlkMHqAGcAidz/p7ruAYmB88Cp2953ufgpYFCwbKhp2WERak6bsI/g08IdgeiCwJ2peSRCLF282bx89wVXffaE5NyEi0qYlTARmtsLMNsZ4zYha5l4ig9T/ojIUY1VeRzzWdmeb2RozW3PgwIHENYlj1a7Djf5sXX70ybHNsl4RkZaW8D4Cd59S13wzmwV8ELjavaobtAQYFLVYHrA3mI4Xr7ndBcACgKKiolbXvTqsf5dUF0FEpEkke9XQNOBrwI3ufjxq1hJgppnlmFkBUAisAlYDhWZWYGbtiHQoL0mmDKmS36MDd04uSHUxRESSlmwfwX8CnYHlZva6mf0IwN03AYuBN4A/AnPc/WzQsTwXWAZsBhYHy7Y5Zsa918e/4KlTjm7aFpG2Iam9lbufW8e8+cD8GPGlwNJkttsWXH5+b55Zvy/VxRARSUh3FrdiegyFiLQEJYJW7Mm7J6W6CCISAqFOBEN7d6z3so9/ejw/v2NCvZfXPWMi0laEukczO6P+efCy83o3aN2t7npXEZE4Qn1GICIiIU8ELfnMn8yM9zdmajgSkVYk1ImgMY+D/s1nJ/LMFy5NuFzNXX27zAy+fv2wyHbVcCQirUio+wgaY9yQHvVarimGoRQRaQmhPiNozn11XatW05CItCahTgQaKUxEJOSJQEREQpAI6jroVzO+iEgIEkFzuzzOjWZXXhCJj69n57KISKqk/VVDdR30N0UfwU9vH0f5yTOMvv9P1eKXntub4vnXcfJMBSPuWwZAh3aRr7tTTmad6/zMpQV0zs1OvnAiIvWQ9omguWVkGB3axd6xZ2VmcPJMRdX7m4ryKD9xmlmThtS5zq9/MP44ByIiTS10TUNXXdCnarqp+giyMzMo7NMp5rzobWRlZnDX5eeQm133GYGISEtKdqjKB8xsfTA62Z/MbEAQNzN7xMyKg/ljoz4zy8y2B69ZyVagoXKy3q9yc14+qo5oEWkrkj0j+I67X+juY4CngW8E8euIjFNcCMwGHgMwsx7AfcAEYDxwn5mlxegrj31yLDPHDap6rzwgIm1FUonA3Y9Gve3I+1drzgAe94iVQDcz6w9MBZa7+2F3PwIsB6YlU4aGij5Sb8qj9nP7dOahj14YtW6lAhFpG5LuLDaz+cBtQBlwZRAeCOyJWqwkiMWLx1rvbCJnE+Tn5ze6fHW1/rTEncW6e1lEWruEZwRmtsLMNsZ4zQBw93vdfRDwC2Bu5cdirMrriNcOui9w9yJ3L+rdu2GDwtSlpXbMOh8QkbYi4RmBu0+p57p+CTxDpA+gBBgUNS8P2BvEr6gRf76e628SzdU0VJ/tiYi0RsleNVQY9fZGYEswvQS4Lbh66GKgzN33AcuAa82se9BJfG0QazbNfUNZ3O0qAYhIG5FsH8FDZnY+UAG8CXw2iC8FpgPFwHHgUwDuftjMHgBWB8t9y90PJ1mGVuXK83vz3NYDtMsK3S0aItJGJZUI3P2jceIOzIkzbyGwMJntJiN6LIDmOGp/9BNj2Xng3arHSYiItHahPmxtjqahDu2yGDmwa9OvWESkmYQuEWi8YBGR6tI+EdTc7WuYSBGR6tI+EYiISN3Clwh0QiAiUk3aJwLt90VE6pb2iaAW9RWLiFQTvkQgIiLVhC8RqK1IRKSa8CUCERGpRolARCTk0j4R1L6hTEREoqV9IhARkbqFLhHkZGWmuggiIq1K2ieCmk1B5/TpmJJyiIi0Vk2SCMzs/5iZm1mv4L2Z2SNmVmxm681sbNSys8xse/Ca1RTbb1BZ1UsgIlJN0qOnmNkg4Brg71Hh64DC4DUBeAyY
YGY9iIxpXESkH3etmS1x9yPJlkNERBqnKc4IHga+SvULdGYAj3vESqCbmfUHpgLL3f1wsPNfDkxrgjK0Ws05LrKISFNIdvD6G4G33H1djVkDgT1R70uCWLx4i9Gg8iIi1SVsGjKzFUC/GLPuBf4ZuDbWx2LEvI54rO3OBmYD5OfnJypmXDVX3tJH6Eo8ItLaJUwE7j4lVtzMRgEFwDqL7O3ygFfNbDyRI/1BUYvnAXuD+BU14s/H2e4CYAFAUVGRGlhERJpJo5uG3H2Du/dx9yHuPoTITn6su+8HlgC3BVcPXQyUufs+YBlwrZl1N7PuRM4mliVfjfrTEbqISHXNdR/BUmAnUAz8GPgcgLsfBh4AVgevbwWxlPj3j41O1aZFRFqNpC8frRScFVROOzAnznILgYVNtd1E6joBGJXXtdm3nxGcggzoltvs2xIRaYwmSwQSW252Jj/8xFiKBndPdVFERGJSImgB00f1T3URRETiSvtnDYmISN3SPhHoulMRkbqlfSKoSVePiohUF7pEICIi1SkRiIiEnBKBiEjIpX0iUJ+AiEjd0j4RiIhI3UKXCPTQORGR6tI+Eeg+AhGRuqV9IhARkbopEYiIhJwSgYhIyCU7eP03zewtM3s9eE2PmjfPzIrNbKuZTY2KTwtixWZ2TzLbr48D5SebexMiIm1aUzyG+mF3//fogJkNB2YCI4ABwAozOy+Y/ShwDZGhLVeb2RJ3f6MJyhHTA08326pFRNJCc41HMANY5O4ngV1mVgyMD+YVu/tOADNbFCzbYntr0y1mIiLVNEUfwVwzW29mC4MB6QEGAnuilikJYvHiIiKSIgkTgZmtMLONMV4zgMeAc4AxwD7gu5Ufi7EqryMea7uzzWyNma05cOBAvSojIiINl7BpyN2n1GdFZvZj4OngbQkwKGp2HrA3mI4Xr7ndBcACgKKioia/L2xwzw5NvUoRkTYpqT4CM+vv7vuCtx8GNgbTS4Bfmtn3iHQWFwKriJwRFJpZAfAWkQ7ljydThsb4yz1X0SVXwzWLiEDyncXfNrMxRJp3dgN3Abj7JjNbTKQT+Awwx93PApjZXGAZkAksdPdNSZahwQZ2a1/n/NGDurFuzzstVBoRkdRKKhG4+611zJsPzI8RXwosTWa7yUj00Lmu7bPJytCVRSISHrqzuAY9nVREwkaJQEQk5JQIatAJgYiEjRJBDQ64axQDEQkPJYIadEYgImGjRBCDqcdYREJEiaAGJQERCRslAhGRkFMiqCH6fODxT4+Pu5yISLpQIqhDh3aZqS6CiEizUyKoQV0EIhI2SgQiIiEXukSQ+KognRKISLiELhHUh+4sFpEwUSIQEQk5JYIYdFOZiIRJ0onAzD5vZlvNbJOZfTsqPs/MioN5U6Pi04JYsZndk+z2k3XXZUOrvVcOEJGwSXbM4iuBGcCF7n7SzPoE8eFExiMeQWTM4hVmdl7wsUeBa4gMcL/azJa4+xvJlCMZ86YP49TZCn76l92pKoKISEolO2bx3cBD7n4SwN1Lg/gMYFEQ32VmxUDlbbrF7r4TwMwWBcu2WCLQNUMiItUl2zR0HjDZzF4xsxfMbFwQHwjsiVquJIjFi4uISIokPCMwsxVAvxiz7g0+3x24GBgHLDazocQ+sHZiJ56Y12qa2WxgNkB+fn6iYtZb947Zdc5XH4GIhE3CRODuU+LNM7O7gac8cuH9KjOrAHoROdIfFLVoHrA3mI4Xr7ndBcACgKKioia7sP9DY3QCIiISLdmmod8BVwEEncHtgIPAEmCmmeWYWQFQCKwCVgOFZlZgZu2IdCgvSbIMDZLo0lBTL4GIhEyyncULgYVmthE4BcwKzg42mdliIp3AZ4A57n4WwMzmAsuATGChu29KsgxJq3kjse4sFpEwSSoRuPsp4JNx5s0H5seILwWWJrPd5qQ+AhEJG91ZHIPuLBaRMFEiEBEJOSUC4NrhfVNdBBGRlFEiACad24u/3HMVoDuLRSR8
lAgCulJIRMJKiaAGdRSLSNgoEYiIhJwSQUAtQyISVkoEMai/QETCRIlARCTklAhqMFOHsYiEixJBoG+XXM7v25n5Hx5VLT5r4mDunFyQolKJiDS/ZJ8+mjbaZWWw7MuXAfDIs9ur4vfPGJmqIomItAidEYiIhJwSgYhIyCWVCMzs12b2evDabWavR82bZ2bFZrbVzKZGxacFsWIzuyeZ7YuISPKSHZjm5sppM/suUBZMDycyDOUIYACwIhjKEuBR4Boi4xqvNrMl7v5GMuUQEZHGa5LOYotcb3kTwfjFwAxgkbufBHaZWTEwPphX7O47g88tCpZVIhARSZGm6iOYDLzt7pWX2wwE9kTNLwli8eIiIpIiCc8IzGwF0C/GrHvd/ffB9C3Ar6I/FmN5J3biifk8BzObDcwGyM/PT1RMERFppISJwN2n1DXfzLKAjwAXRYVLgEFR7/OAvcF0vHjN7S4AFgAUFRXp4T8iIs2kKZqGpgBb3L0kKrYEmGlmOWZWABQCq4DVQKGZFZhZOyIdykuaoAwxlZ843VyrFhFJG03RWTyT6s1CuPsmM1tMpBP4DDDH3c8CmNlcYBmQCSx0901NUIaYTp/ViYSISCJJJwJ3vz1OfD4wP0Z8KbA02e3WR4aeHSciklBa31msp4iKiCSW1olAZwQiIomleSJQJhARSUSJQEQk5NI6ETQ2Dwzq3h6AjjkarkFE0l9a7+kae0bw4EdGMW1kf4b179LEJRIRaX3S+oygsZ3FHdplMW1krKdqiIiknzRPBOojEBFJJK0TgfKAiEhiaZ4IlAlERBJJ60QgIiKJKRGIiIScEoGISMgpEYiIhJwSgYhIyCkRiIiEXFKJwMzGmNlKM3vdzNaY2fggbmb2iJkVm9l6Mxsb9ZlZZrY9eM1KtgKJvPy1K5t7EyIibVqyZwTfBu539zHAN4L3ANcRGae4EJgNPAZgZj2A+4AJwHjgPjPrnmQZ6pTXvQPtszObcxMiIm1asonAgcons3UF9gbTM4DHPWIl0M3M+gNTgeXuftjdjwDLgWlJliGhJXMvYVCP9jz9+Uube1MiIm1Osk8f/RKwzMz+nUhSmRTEBwJ7opYrCWLx4rWY2WwiZxPk5+cnVcjCvp156atXJbW0LOJ+AAAFJ0lEQVQOEZF0lTARmNkKINajOO8Frga+7O5PmtlNwH8DU4BYz3bwOuK1g+4LgAUARUVFMZcREZHkJUwE7j4l3jwzexz4YvD2N8BPgukSYFDUonlEmo1KgCtqxJ+vd2lFRKTJJdtHsBe4PJi+CtgeTC8BbguuHroYKHP3fcAy4Foz6x50El8bxEREJEWS7SO4E/i+mWUBJwja9IGlwHSgGDgOfArA3Q+b2QPA6mC5b7n74STLICIiSUgqEbj7y8BFMeIOzInzmYXAwmS2KyIiTUd3FouIhJwSgYhIyCkRiIiEnEWa81s3MzsAvJnEKnoBB5uoOG1F2OoctvqC6hwWydR5sLv3TrRQm0gEyTKzNe5elOpytKSw1Tls9QXVOSxaos5qGhIRCTklAhGRkAtLIliQ6gKkQNjqHLb6guocFs1e51D0EYiISHxhOSMQEZE40joRmNk0M9saDJl5T6rLkwwzW2hmpWa2MSrWw8yWB8N+Lq8c7a01DRWaDDMbZGbPmdlmM9tkZl8M4mlbbzPLNbNVZrYuqPP9QbzAzF4Jyv9rM2sXxHOC98XB/CFR65oXxLea2dTU1Kh+zCzTzF4zs6eD9+le391mtqFymN8glrrftbun5QvIBHYAQ4F2wDpgeKrLlUR9LgPGAhujYt8G7gmm7wH+LZieDvyByPgPFwOvBPEewM7g3+7BdPdU162OOvcHxgbTnYFtwPB0rndQ9k7BdDbwSlCXxcDMIP4j4O5g+nPAj4LpmcCvg+nhwW8+BygI/hYyU12/Our9j8AvgaeD9+le391ArxqxlP2u0/mMYDxQ7O473f0UsIjIEJptkru/CNR8UusM4H+D6f8F
PhQVbzVDhTaWu+9z91eD6XJgM5ER7dK23kHZjwVvs4OXE3nM+xNBvGadK7+LJ4CrzcyC+CJ3P+nuu4g8CXh8C1ShwcwsD7ieYDyToPxpW986pOx3nc6JoN7DYrZhfT0yzgPBv32CeNJDhbY2QRPAB4gcIad1vYNmkteBUiJ/3DuAd9z9TLBIdPmr6hbMLwN60rbq/B/AV4GK4H1P0ru+EEnufzKztRYZlhdS+LtOdjyC1qzew2KmoaSHCm1NzKwT8CTwJXc/GjkAjL1ojFibq7e7nwXGmFk34LfAsFiLBf+26Tqb2QeBUndfa2ZXVIZjLJoW9Y1yibvvNbM+wHIz21LHss1e53Q+I4g3XGY6eTs4RST4tzSI1zVUaJv6Tswsm0gS+IW7PxWE077eAO7+DpGhXC8m0hxQeeAWXf6qugXzuxJpQmwrdb4EuNHMdhNpvr2KyBlCutYXAHffG/xbSiTZjyeFv+t0TgSrgcLg6oN2RDqWlqS4TE1tCVB5pcAs4PdR8TY/VGjQ9vvfwGZ3/17UrLStt5n1Ds4EMLP2wBQifSPPAf8QLFazzpXfxT8Af/ZIT+ISYGZwlU0BUAisapla1J+7z3P3PHcfQuRv9M/u/gnStL4AZtbRzDpXThP5PW4klb/rVPeeN+eLSG/7NiJtrPemujxJ1uVXwD7gNJEjgTuItI0+S2Ss6GeBHsGyBjwa1HsDUBS1nk8T6UgrBj6V6nolqPOlRE511wOvB6/p6Vxv4ELgtaDOG4FvBPGhRHZsxcBvgJwgnhu8Lw7mD41a173Bd7EVuC7VdatH3a/g/auG0ra+Qd3WBa9NlfumVP6udWexiEjIpXPTkIiI1IMSgYhIyCkRiIiEnBKBiEjIKRGIiIScEoGISMgpEYiIhJwSgYhIyP1/7+Egn1QoTm0AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "agent = SARSAAgent(env)\n",
    "\n",
    "# Training: epsilon-greedy episodes, updating Q after every transition.\n",
    "episodes = 5000  # number of training episodes\n",
    "train_episode_rewards = []  # kept separate so the learning curve is not overwritten below\n",
    "for episode in range(episodes):\n",
    "    episode_reward = play_sarsa(env, agent, train=True)\n",
    "    train_episode_rewards.append(episode_reward)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(train_episode_rewards)\n",
    "ax.set(title='SARSA training', xlabel='episode', ylabel='episode reward')\n",
    "\n",
    "# Testing: act purely greedily with the learned Q table.\n",
    "agent.epsilon = 0.  # disable exploration\n",
    "test_episodes = 100  # number of evaluation episodes\n",
    "episode_rewards = [play_sarsa(env, agent) for _ in range(test_episodes)]\n",
    "print('平均回合奖励 = {} / {} = {}'.format(sum(episode_rewards),\n",
    "        len(episode_rewards), np.mean(episode_rewards)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "显示学到的动作价值估计（SARSA 以 ε-贪心策略学习，所得为近似最优的价值）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>-2.600026</td>\n",
       "      <td>-2.593051</td>\n",
       "      <td>-2.612598</td>\n",
       "      <td>-1.715516</td>\n",
       "      <td>4.249763</td>\n",
       "      <td>-3.772742</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>-2.067980</td>\n",
       "      <td>-1.351987</td>\n",
       "      <td>-1.676313</td>\n",
       "      <td>-2.233975</td>\n",
       "      <td>7.439644</td>\n",
       "      <td>-3.734128</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>-2.944770</td>\n",
       "      <td>-3.051257</td>\n",
       "      <td>-2.687768</td>\n",
       "      <td>-2.947658</td>\n",
       "      <td>1.979383</td>\n",
       "      <td>-4.174560</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>-3.642111</td>\n",
       "      <td>-4.449279</td>\n",
       "      <td>-4.419784</td>\n",
       "      <td>-4.465809</td>\n",
       "      <td>-5.593674</td>\n",
       "      <td>-5.602355</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>-4.887117</td>\n",
       "      <td>-5.017114</td>\n",
       "      <td>-4.989397</td>\n",
       "      <td>-4.996423</td>\n",
       "      <td>-5.593139</td>\n",
       "      <td>-5.603846</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>-4.180342</td>\n",
       "      <td>-4.177074</td>\n",
       "      <td>-3.744672</td>\n",
       "      <td>-4.192496</td>\n",
       "      <td>-5.595174</td>\n",
       "      <td>-4.642345</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>0.324690</td>\n",
       "      <td>-3.624851</td>\n",
       "      <td>-3.536758</td>\n",
       "      <td>-3.606046</td>\n",
       "      <td>-4.627411</td>\n",
       "      <td>-3.774342</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>-3.701803</td>\n",
       "      <td>-5.172822</td>\n",
       "      <td>-5.172463</td>\n",
       "      <td>-5.168523</td>\n",
       "      <td>-5.593471</td>\n",
       "      <td>-5.603375</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>-3.213076</td>\n",
       "      <td>-4.705947</td>\n",
       "      <td>-4.738782</td>\n",
       "      <td>-4.737740</td>\n",
       "      <td>-6.466847</td>\n",
       "      <td>-5.602483</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>-5.400523</td>\n",
       "      <td>-5.433982</td>\n",
       "      <td>-5.069151</td>\n",
       "      <td>-5.441407</td>\n",
       "      <td>-5.594126</td>\n",
       "      <td>-7.302430</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>-4.480154</td>\n",
       "      <td>-4.508876</td>\n",
       "      <td>-4.192068</td>\n",
       "      <td>-4.527803</td>\n",
       "      <td>-5.593906</td>\n",
       "      <td>-5.604151</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>-4.932966</td>\n",
       "      <td>-4.991243</td>\n",
       "      <td>-4.736773</td>\n",
       "      <td>-4.934336</td>\n",
       "      <td>-5.593555</td>\n",
       "      <td>-5.605058</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>3.663884</td>\n",
       "      <td>-0.190000</td>\n",
       "      <td>2.468896</td>\n",
       "      <td>1.529000</td>\n",
       "      <td>-0.019000</td>\n",
       "      <td>20.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>5.889241</td>\n",
       "      <td>-1.438923</td>\n",
       "      <td>-1.717338</td>\n",
       "      <td>-1.729236</td>\n",
       "      <td>-2.241977</td>\n",
       "      <td>-1.930881</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>9.496498</td>\n",
       "      <td>-1.424756</td>\n",
       "      <td>-1.494616</td>\n",
       "      <td>-1.431893</td>\n",
       "      <td>-1.900000</td>\n",
       "      <td>-1.484599</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>3.673671</td>\n",
       "      <td>-2.368791</td>\n",
       "      <td>-2.215117</td>\n",
       "      <td>-1.897817</td>\n",
       "      <td>-3.764300</td>\n",
       "      <td>-2.351717</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>-2.885966</td>\n",
       "      <td>-2.929858</td>\n",
       "      <td>-2.951218</td>\n",
       "      <td>1.628554</td>\n",
       "      <td>-3.764329</td>\n",
       "      <td>-3.773480</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>-2.501994</td>\n",
       "      <td>-2.522568</td>\n",
       "      <td>-2.479618</td>\n",
       "      <td>3.991413</td>\n",
       "      <td>-2.857404</td>\n",
       "      <td>-2.881117</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>-3.401930</td>\n",
       "      <td>-3.429584</td>\n",
       "      <td>-3.390730</td>\n",
       "      <td>1.058025</td>\n",
       "      <td>-3.764710</td>\n",
       "      <td>-3.774153</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>-3.019153</td>\n",
       "      <td>-4.403842</td>\n",
       "      <td>-4.432891</td>\n",
       "      <td>-4.351873</td>\n",
       "      <td>-5.593881</td>\n",
       "      <td>-5.603550</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>-4.201486</td>\n",
       "      <td>-4.815455</td>\n",
       "      <td>-4.836067</td>\n",
       "      <td>-4.848010</td>\n",
       "      <td>-5.593854</td>\n",
       "      <td>-5.602810</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>-1.908931</td>\n",
       "      <td>-4.022056</td>\n",
       "      <td>-4.087451</td>\n",
       "      <td>-4.082118</td>\n",
       "      <td>-5.593085</td>\n",
       "      <td>-5.602737</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>-3.651628</td>\n",
       "      <td>-3.673430</td>\n",
       "      <td>-3.741787</td>\n",
       "      <td>-2.555545</td>\n",
       "      <td>-4.701294</td>\n",
       "      <td>-3.775230</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>-4.834480</td>\n",
       "      <td>-5.300449</td>\n",
       "      <td>-5.339533</td>\n",
       "      <td>-5.271594</td>\n",
       "      <td>-5.595239</td>\n",
       "      <td>-6.445467</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>470</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>471</th>\n",
       "      <td>-5.826520</td>\n",
       "      <td>-5.361917</td>\n",
       "      <td>-5.782439</td>\n",
       "      <td>-5.769232</td>\n",
       "      <td>-7.140112</td>\n",
       "      <td>-6.494966</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>472</th>\n",
       "      <td>-3.376717</td>\n",
       "      <td>-2.575815</td>\n",
       "      <td>-3.384449</td>\n",
       "      <td>-2.926946</td>\n",
       "      <td>2.565026</td>\n",
       "      <td>-4.476473</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>473</th>\n",
       "      <td>-1.178989</td>\n",
       "      <td>-1.921930</td>\n",
       "      <td>-1.894414</td>\n",
       "      <td>-0.625845</td>\n",
       "      <td>5.827701</td>\n",
       "      <td>-3.772862</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>474</th>\n",
       "      <td>-2.564208</td>\n",
       "      <td>-2.777640</td>\n",
       "      <td>-2.976987</td>\n",
       "      <td>-2.320540</td>\n",
       "      <td>2.531285</td>\n",
       "      <td>-5.158846</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>475</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>476</th>\n",
       "      <td>-2.791922</td>\n",
       "      <td>3.970145</td>\n",
       "      <td>-2.840441</td>\n",
       "      <td>-2.274948</td>\n",
       "      <td>-4.025782</td>\n",
       "      <td>-2.782034</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>477</th>\n",
       "      <td>0.231274</td>\n",
       "      <td>7.626263</td>\n",
       "      <td>-1.570958</td>\n",
       "      <td>-1.598804</td>\n",
       "      <td>-2.106045</td>\n",
       "      <td>-1.641895</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>478</th>\n",
       "      <td>-2.505515</td>\n",
       "      <td>3.947637</td>\n",
       "      <td>-2.490659</td>\n",
       "      <td>-2.514763</td>\n",
       "      <td>-4.114975</td>\n",
       "      <td>-2.478691</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>479</th>\n",
       "      <td>1.529000</td>\n",
       "      <td>-0.100000</td>\n",
       "      <td>-0.100000</td>\n",
       "      <td>3.076100</td>\n",
       "      <td>-0.019001</td>\n",
       "      <td>20.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>480</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>481</th>\n",
       "      <td>-4.846765</td>\n",
       "      <td>-4.819097</td>\n",
       "      <td>-4.906512</td>\n",
       "      <td>-4.851796</td>\n",
       "      <td>-5.594329</td>\n",
       "      <td>-6.213804</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>482</th>\n",
       "      <td>-4.502797</td>\n",
       "      <td>-4.371799</td>\n",
       "      <td>-4.422956</td>\n",
       "      <td>-4.412366</td>\n",
       "      <td>-5.593346</td>\n",
       "      <td>-5.605950</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>483</th>\n",
       "      <td>-5.432701</td>\n",
       "      <td>-5.388436</td>\n",
       "      <td>-5.408471</td>\n",
       "      <td>-5.392289</td>\n",
       "      <td>-5.683992</td>\n",
       "      <td>-5.603984</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>484</th>\n",
       "      <td>-4.138971</td>\n",
       "      <td>-1.984586</td>\n",
       "      <td>-4.088252</td>\n",
       "      <td>-4.125033</td>\n",
       "      <td>-4.587827</td>\n",
       "      <td>-5.602073</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>485</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>486</th>\n",
       "      <td>-4.306615</td>\n",
       "      <td>-3.876250</td>\n",
       "      <td>-4.297456</td>\n",
       "      <td>-4.244949</td>\n",
       "      <td>-5.593065</td>\n",
       "      <td>-4.667585</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>487</th>\n",
       "      <td>-3.710346</td>\n",
       "      <td>-1.718120</td>\n",
       "      <td>-3.685916</td>\n",
       "      <td>-3.654622</td>\n",
       "      <td>-3.764096</td>\n",
       "      <td>-4.654249</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>488</th>\n",
       "      <td>-4.509396</td>\n",
       "      <td>-3.901351</td>\n",
       "      <td>-4.569172</td>\n",
       "      <td>-4.530301</td>\n",
       "      <td>-4.612896</td>\n",
       "      <td>-4.602810</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>489</th>\n",
       "      <td>-6.087150</td>\n",
       "      <td>-6.185883</td>\n",
       "      <td>-6.106261</td>\n",
       "      <td>-6.018208</td>\n",
       "      <td>-7.388198</td>\n",
       "      <td>-7.396675</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>490</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>491</th>\n",
       "      <td>-5.907680</td>\n",
       "      <td>-5.618526</td>\n",
       "      <td>-5.890413</td>\n",
       "      <td>-5.854301</td>\n",
       "      <td>-6.533011</td>\n",
       "      <td>-7.403322</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>492</th>\n",
       "      <td>-3.699239</td>\n",
       "      <td>-3.726140</td>\n",
       "      <td>-3.718504</td>\n",
       "      <td>0.281332</td>\n",
       "      <td>-3.764229</td>\n",
       "      <td>-3.773421</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>493</th>\n",
       "      <td>-2.653038</td>\n",
       "      <td>-2.582561</td>\n",
       "      <td>-2.673425</td>\n",
       "      <td>2.907245</td>\n",
       "      <td>-3.764590</td>\n",
       "      <td>-3.773112</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>494</th>\n",
       "      <td>-3.459074</td>\n",
       "      <td>-3.356164</td>\n",
       "      <td>-3.352881</td>\n",
       "      <td>1.113477</td>\n",
       "      <td>-3.764699</td>\n",
       "      <td>-3.773177</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>495</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>496</th>\n",
       "      <td>-2.595107</td>\n",
       "      <td>-2.574089</td>\n",
       "      <td>-2.625771</td>\n",
       "      <td>-2.047315</td>\n",
       "      <td>-2.761176</td>\n",
       "      <td>-3.773025</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>497</th>\n",
       "      <td>-1.267358</td>\n",
       "      <td>-1.156897</td>\n",
       "      <td>-1.273078</td>\n",
       "      <td>-1.261785</td>\n",
       "      <td>-3.591614</td>\n",
       "      <td>-1.909000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>498</th>\n",
       "      <td>-2.352192</td>\n",
       "      <td>-2.303140</td>\n",
       "      <td>-2.291246</td>\n",
       "      <td>-2.269872</td>\n",
       "      <td>-2.869774</td>\n",
       "      <td>-2.762138</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>499</th>\n",
       "      <td>-0.190000</td>\n",
       "      <td>-0.199000</td>\n",
       "      <td>-0.190000</td>\n",
       "      <td>-0.010000</td>\n",
       "      <td>-1.900000</td>\n",
       "      <td>-1.909000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>500 rows × 6 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "            0         1         2         3         4          5\n",
       "0    0.000000  0.000000  0.000000  0.000000  0.000000   0.000000\n",
       "1   -2.600026 -2.593051 -2.612598 -1.715516  4.249763  -3.772742\n",
       "2   -2.067980 -1.351987 -1.676313 -2.233975  7.439644  -3.734128\n",
       "3   -2.944770 -3.051257 -2.687768 -2.947658  1.979383  -4.174560\n",
       "4   -3.642111 -4.449279 -4.419784 -4.465809 -5.593674  -5.602355\n",
       "5    0.000000  0.000000  0.000000  0.000000  0.000000   0.000000\n",
       "6   -4.887117 -5.017114 -4.989397 -4.996423 -5.593139  -5.603846\n",
       "7   -4.180342 -4.177074 -3.744672 -4.192496 -5.595174  -4.642345\n",
       "8    0.324690 -3.624851 -3.536758 -3.606046 -4.627411  -3.774342\n",
       "9   -3.701803 -5.172822 -5.172463 -5.168523 -5.593471  -5.603375\n",
       "10   0.000000  0.000000  0.000000  0.000000  0.000000   0.000000\n",
       "11  -3.213076 -4.705947 -4.738782 -4.737740 -6.466847  -5.602483\n",
       "12  -5.400523 -5.433982 -5.069151 -5.441407 -5.594126  -7.302430\n",
       "13  -4.480154 -4.508876 -4.192068 -4.527803 -5.593906  -5.604151\n",
       "14  -4.932966 -4.991243 -4.736773 -4.934336 -5.593555  -5.605058\n",
       "15   0.000000  0.000000  0.000000  0.000000  0.000000   0.000000\n",
       "16   3.663884 -0.190000  2.468896  1.529000 -0.019000  20.000000\n",
       "17   5.889241 -1.438923 -1.717338 -1.729236 -2.241977  -1.930881\n",
       "18   9.496498 -1.424756 -1.494616 -1.431893 -1.900000  -1.484599\n",
       "19   3.673671 -2.368791 -2.215117 -1.897817 -3.764300  -2.351717\n",
       "20   0.000000  0.000000  0.000000  0.000000  0.000000   0.000000\n",
       "21  -2.885966 -2.929858 -2.951218  1.628554 -3.764329  -3.773480\n",
       "22  -2.501994 -2.522568 -2.479618  3.991413 -2.857404  -2.881117\n",
       "23  -3.401930 -3.429584 -3.390730  1.058025 -3.764710  -3.774153\n",
       "24  -3.019153 -4.403842 -4.432891 -4.351873 -5.593881  -5.603550\n",
       "25   0.000000  0.000000  0.000000  0.000000  0.000000   0.000000\n",
       "26  -4.201486 -4.815455 -4.836067 -4.848010 -5.593854  -5.602810\n",
       "27  -1.908931 -4.022056 -4.087451 -4.082118 -5.593085  -5.602737\n",
       "28  -3.651628 -3.673430 -3.741787 -2.555545 -4.701294  -3.775230\n",
       "29  -4.834480 -5.300449 -5.339533 -5.271594 -5.595239  -6.445467\n",
       "..        ...       ...       ...       ...       ...        ...\n",
       "470  0.000000  0.000000  0.000000  0.000000  0.000000   0.000000\n",
       "471 -5.826520 -5.361917 -5.782439 -5.769232 -7.140112  -6.494966\n",
       "472 -3.376717 -2.575815 -3.384449 -2.926946  2.565026  -4.476473\n",
       "473 -1.178989 -1.921930 -1.894414 -0.625845  5.827701  -3.772862\n",
       "474 -2.564208 -2.777640 -2.976987 -2.320540  2.531285  -5.158846\n",
       "475  0.000000  0.000000  0.000000  0.000000  0.000000   0.000000\n",
       "476 -2.791922  3.970145 -2.840441 -2.274948 -4.025782  -2.782034\n",
       "477  0.231274  7.626263 -1.570958 -1.598804 -2.106045  -1.641895\n",
       "478 -2.505515  3.947637 -2.490659 -2.514763 -4.114975  -2.478691\n",
       "479  1.529000 -0.100000 -0.100000  3.076100 -0.019001  20.000000\n",
       "480  0.000000  0.000000  0.000000  0.000000  0.000000   0.000000\n",
       "481 -4.846765 -4.819097 -4.906512 -4.851796 -5.594329  -6.213804\n",
       "482 -4.502797 -4.371799 -4.422956 -4.412366 -5.593346  -5.605950\n",
       "483 -5.432701 -5.388436 -5.408471 -5.392289 -5.683992  -5.603984\n",
       "484 -4.138971 -1.984586 -4.088252 -4.125033 -4.587827  -5.602073\n",
       "485  0.000000  0.000000  0.000000  0.000000  0.000000   0.000000\n",
       "486 -4.306615 -3.876250 -4.297456 -4.244949 -5.593065  -4.667585\n",
       "487 -3.710346 -1.718120 -3.685916 -3.654622 -3.764096  -4.654249\n",
       "488 -4.509396 -3.901351 -4.569172 -4.530301 -4.612896  -4.602810\n",
       "489 -6.087150 -6.185883 -6.106261 -6.018208 -7.388198  -7.396675\n",
       "490  0.000000  0.000000  0.000000  0.000000  0.000000   0.000000\n",
       "491 -5.907680 -5.618526 -5.890413 -5.854301 -6.533011  -7.403322\n",
       "492 -3.699239 -3.726140 -3.718504  0.281332 -3.764229  -3.773421\n",
       "493 -2.653038 -2.582561 -2.673425  2.907245 -3.764590  -3.773112\n",
       "494 -3.459074 -3.356164 -3.352881  1.113477 -3.764699  -3.773177\n",
       "495  0.000000  0.000000  0.000000  0.000000  0.000000   0.000000\n",
       "496 -2.595107 -2.574089 -2.625771 -2.047315 -2.761176  -3.773025\n",
       "497 -1.267358 -1.156897 -1.273078 -1.261785 -3.591614  -1.909000\n",
       "498 -2.352192 -2.303140 -2.291246 -2.269872 -2.869774  -2.762138\n",
       "499 -0.190000 -0.199000 -0.190000 -0.010000 -1.900000  -1.909000\n",
       "\n",
       "[500 rows x 6 columns]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Learned Q table: one row per state, one column per action\n",
    "pd.DataFrame(data=agent.q)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "显示最优策略估计"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>470</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>471</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>472</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>473</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>474</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>475</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>476</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>477</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>478</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>479</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>480</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>481</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>482</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>483</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>484</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>485</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>486</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>487</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>488</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>489</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>490</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>491</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>492</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>493</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>494</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>495</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>496</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>497</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>498</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>499</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>500 rows × 6 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       0    1    2    3    4    5\n",
       "0    1.0  0.0  0.0  0.0  0.0  0.0\n",
       "1    0.0  0.0  0.0  0.0  1.0  0.0\n",
       "2    0.0  0.0  0.0  0.0  1.0  0.0\n",
       "3    0.0  0.0  0.0  0.0  1.0  0.0\n",
       "4    1.0  0.0  0.0  0.0  0.0  0.0\n",
       "5    1.0  0.0  0.0  0.0  0.0  0.0\n",
       "6    1.0  0.0  0.0  0.0  0.0  0.0\n",
       "7    0.0  0.0  1.0  0.0  0.0  0.0\n",
       "8    1.0  0.0  0.0  0.0  0.0  0.0\n",
       "9    1.0  0.0  0.0  0.0  0.0  0.0\n",
       "10   1.0  0.0  0.0  0.0  0.0  0.0\n",
       "11   1.0  0.0  0.0  0.0  0.0  0.0\n",
       "12   0.0  0.0  1.0  0.0  0.0  0.0\n",
       "13   0.0  0.0  1.0  0.0  0.0  0.0\n",
       "14   0.0  0.0  1.0  0.0  0.0  0.0\n",
       "15   1.0  0.0  0.0  0.0  0.0  0.0\n",
       "16   0.0  0.0  0.0  0.0  0.0  1.0\n",
       "17   1.0  0.0  0.0  0.0  0.0  0.0\n",
       "18   1.0  0.0  0.0  0.0  0.0  0.0\n",
       "19   1.0  0.0  0.0  0.0  0.0  0.0\n",
       "20   1.0  0.0  0.0  0.0  0.0  0.0\n",
       "21   0.0  0.0  0.0  1.0  0.0  0.0\n",
       "22   0.0  0.0  0.0  1.0  0.0  0.0\n",
       "23   0.0  0.0  0.0  1.0  0.0  0.0\n",
       "24   1.0  0.0  0.0  0.0  0.0  0.0\n",
       "25   1.0  0.0  0.0  0.0  0.0  0.0\n",
       "26   1.0  0.0  0.0  0.0  0.0  0.0\n",
       "27   1.0  0.0  0.0  0.0  0.0  0.0\n",
       "28   0.0  0.0  0.0  1.0  0.0  0.0\n",
       "29   1.0  0.0  0.0  0.0  0.0  0.0\n",
       "..   ...  ...  ...  ...  ...  ...\n",
       "470  1.0  0.0  0.0  0.0  0.0  0.0\n",
       "471  0.0  1.0  0.0  0.0  0.0  0.0\n",
       "472  0.0  0.0  0.0  0.0  1.0  0.0\n",
       "473  0.0  0.0  0.0  0.0  1.0  0.0\n",
       "474  0.0  0.0  0.0  0.0  1.0  0.0\n",
       "475  1.0  0.0  0.0  0.0  0.0  0.0\n",
       "476  0.0  1.0  0.0  0.0  0.0  0.0\n",
       "477  0.0  1.0  0.0  0.0  0.0  0.0\n",
       "478  0.0  1.0  0.0  0.0  0.0  0.0\n",
       "479  0.0  0.0  0.0  0.0  0.0  1.0\n",
       "480  1.0  0.0  0.0  0.0  0.0  0.0\n",
       "481  0.0  1.0  0.0  0.0  0.0  0.0\n",
       "482  0.0  1.0  0.0  0.0  0.0  0.0\n",
       "483  0.0  1.0  0.0  0.0  0.0  0.0\n",
       "484  0.0  1.0  0.0  0.0  0.0  0.0\n",
       "485  1.0  0.0  0.0  0.0  0.0  0.0\n",
       "486  0.0  1.0  0.0  0.0  0.0  0.0\n",
       "487  0.0  1.0  0.0  0.0  0.0  0.0\n",
       "488  0.0  1.0  0.0  0.0  0.0  0.0\n",
       "489  0.0  0.0  0.0  1.0  0.0  0.0\n",
       "490  1.0  0.0  0.0  0.0  0.0  0.0\n",
       "491  0.0  1.0  0.0  0.0  0.0  0.0\n",
       "492  0.0  0.0  0.0  1.0  0.0  0.0\n",
       "493  0.0  0.0  0.0  1.0  0.0  0.0\n",
       "494  0.0  0.0  0.0  1.0  0.0  0.0\n",
       "495  1.0  0.0  0.0  0.0  0.0  0.0\n",
       "496  0.0  0.0  0.0  1.0  0.0  0.0\n",
       "497  0.0  1.0  0.0  0.0  0.0  0.0\n",
       "498  0.0  0.0  0.0  1.0  0.0  0.0\n",
       "499  0.0  0.0  0.0  1.0  0.0  0.0\n",
       "\n",
       "[500 rows x 6 columns]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One-hot greedy policy: row s has 1.0 in column argmax_a q[s, a] (ties -> lowest action index), 0.0 elsewhere\n",
    "policy = np.eye(agent.action_n)[np.argmax(agent.q, axis=-1)]\n",
    "pd.DataFrame(data=policy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 期望 SARSA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ExpectedSARSAAgent:\n",
    "    \"\"\"Tabular expected-SARSA agent with an epsilon-greedy behaviour policy.\"\"\"\n",
    "\n",
    "    def __init__(self, env, gamma=0.9, learning_rate=0.1, epsilon=.01):\n",
    "        \"\"\"Create a zero-initialised Q table sized from the env's discrete spaces.\n",
    "\n",
    "        Args:\n",
    "            env: environment exposing ``observation_space.n`` and ``action_space.n``.\n",
    "            gamma: discount factor applied to the bootstrapped value.\n",
    "            learning_rate: step size of the TD update.\n",
    "            epsilon: probability that ``decide`` takes a uniformly random action.\n",
    "        \"\"\"\n",
    "        self.gamma = gamma\n",
    "        self.learning_rate = learning_rate\n",
    "        self.epsilon = epsilon\n",
    "        self.q = np.zeros((env.observation_space.n, env.action_space.n))\n",
    "        self.action_n = env.action_space.n\n",
    "\n",
    "    def decide(self, state):\n",
    "        \"\"\"Return an epsilon-greedy action for ``state`` w.r.t. the current Q table.\"\"\"\n",
    "        if np.random.uniform() > self.epsilon:\n",
    "            action = self.q[state].argmax()\n",
    "        else:\n",
    "            action = np.random.randint(self.action_n)\n",
    "        return action\n",
    "\n",
    "    def learn(self, state, action, reward, next_state, done):\n",
    "        \"\"\"Apply one expected-SARSA update to ``q[state, action]``.\n",
    "\n",
    "        The target bootstraps from the expectation of ``q[next_state]`` under\n",
    "        the same epsilon-greedy policy used by ``decide``: the random branch\n",
    "        picks each action with probability ``epsilon / action_n`` (hence the\n",
    "        mean, not the sum), and the greedy branch picks the argmax with\n",
    "        probability ``1 - epsilon``.\n",
    "        \"\"\"\n",
    "        v = (self.q[next_state].mean() * self.epsilon +\n",
    "             self.q[next_state].max() * (1. - self.epsilon))\n",
    "        # (1. - done) zeroes the bootstrapped term when the transition ends the episode.\n",
    "        u = reward + self.gamma * v * (1. - done)\n",
    "        td_error = u - self.q[state, action]\n",
    "        self.q[state, action] += self.learning_rate * td_error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def play_qlearning(env, agent, train=False, render=False):\n",
    "    episode_reward = 0\n",
    "    observation = env.reset()\n",
    "    while True:\n",
    "        if render:\n",
    "            env.render()\n",
    "        action = agent.decide(observation)\n",
    "        next_observation, reward, done, _ = env.step(action)\n",
    "        episode_reward += reward\n",
    "        if train:\n",
    "            agent.learn(observation, action, reward, next_observation,\n",
    "                    done)\n",
    "        if done:\n",
    "            break\n",
    "        observation = next_observation\n",
    "    return episode_reward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "平均回合奖励 = 887 / 100 = 8.87\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAD8CAYAAAB6paOMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xt8VPWd//HXJ1cgQAgkXAMCAioCcomACAIKKGJFrW5Rt9KqP1ov27pqXaxWu22pbC9rb65d2mV7WVu0VitVWoR67XZV0IKCNyKgIipBEbnIJfD5/TEnYZLMZCaZCZPMeT8fj3nkzPd855zvN5mczznf7/d8j7k7IiISXjmZLoCIiGSWAoGISMgpEIiIhJwCgYhIyCkQiIiEnAKBiEjIKRCIiIScAoGISMgpEIiIhFxepguQjNLSUu/fv3+miyEi0qY8//zz2929LFG+NhEI+vfvz+rVqzNdDBGRNsXM3kwmn5qGRERCToFARCTkFAhEREJOgUBEJOQUCEREQk6BQEQk5BQIRERCToEgSS9u+YiXtuwE4JMDh9j5yUEA3J0tO/by0d4DtXl3fnKQ9z/ex76Dh3jy9SrWvbOTxX/dFHO7O/ceZOnarXXSPjlwiI1Vu3l35ycA7NhzgP3Vh2r3t+3jfUeWd+2LW+adew+y9aNPOFB9OGH9arZZ3x/XbuWhNe/wwls7atN27TvInv3VdfJ9tPcA+w5Gfi+bt+9psB1359fPvMlTr1exevOHvP/xPnbvr2b3/urI5/YeZNFTb/DxvoMJy1rj8GHnK79by7p3dibM+/SGqpjlqvHJgUPc//wWVr78ftw8H+45wMMv1v1b7dp3kIfWvMPBQ4f5YPd+tu/ez1OvV/Hnde+y7+AhHn91G8+/eeR39/7H+/jj2q1UbtvF4cMNHxP7+Kvb+OuG7fytcjt/Xvcu33r4Ze5b9TY790b2U7ltNwCHDjv3rXqb6kOHqdq1n1/+bTNPvl7FWx/sjVn2HXsif5/GbNq+h68++BKvvPtxbdqe/dU88MKWOt/5ZGzbta9B/aK/Yx/s3s+S597i28teYWPVbnbV+7u7O+9/vC/mfv+87r1Gv/eJbHh/F/sOHmLfwUNUbtvFOx99wicHIsuNefODPTy6/r06/+uxbPt4H+7O4cMN/z+j/3+37drH3gPVLHrqDQ5UH2bnJwd584P439GWZG3hmcUVFRWeiRvK/v7WDkaUdyE3x+g//xEANi+cxcR/e4wtOz7hnivHsX7rTr697NWktnf28J7cOmsoL7y1gz+te49HXny3zvqpx5Vx6qBSvvXIK00u68n9S1i1OXLAObasiDeqkvtCnTakjKder0p6P186YzA/+ssGAK6Zeix3Pf4Gs0f25qE1WxN8svle+Np0bnnwJf607r24eaYcV8YTr0XqccGoPjzw93ca3eZVU45lyXNvsWNv/IPbxWP7UdqxgHXv7KS4fT5/COq47EuTuOTnz/BRI59tjtvOGco3Hn455e1sWDCTry9dzz3PvtVg3cRBpfy1cnvt+3jflVMGduP/Nn7QIH3msJ5x/w7D+xRz+vHd+WHw/QD41nnDGNm3C+f8+K/NqUqtAaVFbIoTyLt3KuT+L06gIC+H8Xf8JWaewrwcLh7bj1/8bXOj+xnZtwtr3v4oYXmiv2OTBpfymZP7cu1v/p7wc/FEf38Bykvas2VH5ETw11eMZdLghDcHx2Rmz7t7RcJ8mQoEZnYW8EMgF/i5uy+MlzcTgeD5N3fw6bv/xnXTBvO5Cf0Z+Y0VANx96WiuuueF2nyTBpfy9Ibt8TYjIpKyzQtnNetzyQaCjDQNmVkucBcwExgKXGxmQzNRlvr+uHYrp3/vCbZ+FInGP1i5oTYIAHWCAKAgICJtXqbmGhoLVLr7RgAzWwLMBlK/Lk7RDb9by4Hqw0m1q4uItLTHb5zS4vvIVGdxH+DtqPdbgrTM
C1rKzDJbDGnbph4Xu033ha9Nb5H9TTi2W6PrH/6nibXLmxfOitnU0C6/7uHg1lknsHnhrJjbvv1TQ9m8cBb3f/GUOunLrzutzvvpQ3vULp89vGeD7USX48cXj4pb/geunsCfr5vEmtums/a2GbXpt50zlPEDu7L2thmsvW1Gk5pQ8nMb/pNvXjirTplrfPXs42uXH7rmVJ6/dRqTh5Tx9E1T6+Q768S6dbz70tFx9/+5Cf3r1GXTHWfXLl83bTCb7jibzQtnMaC0KHFlUpSpQBDrMFuns8LM5pnZajNbXVWVfGdmulTHGNEhrdN3LxzRIO1zE/o3eTsrrz+NP147MXFG4MtnDG50/fcuOql2OfqA07WogJ9ccuSAN6xPZ5796hncOusEOrfLY+X1kxPu++mbpvKTS0Zx1yVHDjI1nYkn9Opcm3bB6CPnVsP6FDOivJhvzD4x5jZPG1LGry4fVyfNgrOhH8wZyfyZxzN5yJHgNrxPMQAV/bvWpv37P5zEcT071b7/2jlD+dllFVwwKlKO0+p1ePbv1qHO+8K8I4ejZ796Bk/fNJWT+5cAMLpfCcf37EyXDgUUd8ivzXf5xAEsmXcKxR3y66QD9OnSHoBPjy5nWJ/OlJe0r1037YQe/PjiugfpmqD2s8sq2LxwFt2KCmrXzTvtWG6ddQIzh/VkRHkx3ToW8svLx9K365E6fOXM4/jmecPqbHPm8F7EU9w+UuZH//k0Fl4wHDNj/b+eyaP/fBrXTRtS+/s/GjLVNLQF6Bv1vhyoM+zE3RcBiyDSWXz0ihZx0/0vHu1dhsJ10wbzg5UbEmdsglNinLHWHwSRl2O1wT3eCJRB3Y8cxApyczhwqGHzYLv8HJbMO4WRfbvwu9Vvs3Vnw2GM9c9KNyw4mwl3/KU27zkjenPOiN518lw5aSBXThoYr4ps/PbZTP7e43Rul0/frh1qD0CzRsyiatd+yjoV0rdreyYPKWP5+vd5aM073HTm8UwcVMr23fsBWNpIkMvLMUb168K5J/Vmx94DPL1hOznBcah7p3Z8cfKxfHHysQC1+6uv5u/wx2snkpMDJ/aOBItvnDeMqcd351Mn9Wb+Ay8BkQC4eO7JAPzfzafz2nu7aptje3QupEfndgD8z5Xj2Hew4d/hr/8yNe4oohp/uWEy+w4eoqgwj/zcHA4fdnbsPcCTr1dxzojeFOTlsHnhLHbsOUDn9vnk5tQ98P7m/43nht+t4fZPRYJn5G/UcD/3XDmOpzZUcc3UQUDkqm/brn3s2R8ZrvvYDZPZvvsAl/zsGaoPO4s/V8Hlv1jNlOCqcUiPTgzpEfnuFRXm1S4fTZkKBKuAwWY2AHgHmANckqGyAPDTJ99g4qBSHF0JRCvMy2F/gv6S3BzjUJwrqH5dO9CruB3PbvoQgOumDUk6EAzp0ZHX398dc93TN01l0nceByDHjDkn92XJqiOtjV06FPDdC0cwqHtHqg87u/dX8/n/XsWiz45hxok9a4cD16g5ewRYef1kSjrkM+ZbKynpkF9niOlzt0yjc7vImedjN07htfd2cVLfLgBs3r6HDgW5tXmL2+fXjoF/5EuTag/IyRhzTAm/v2pCbTlzcoynbzo9Zt6ag3JNcLlwTDkXjikH4ILR5UntL8eM/NwcfnTxKG57aF0QCGKfkcYKAl2LCuhVHPkdDi8vrrOuY2Eenzqpd1DGXhTm5fL9fzhyxdSruD29itvz6PrIsNThfbrUrivMy6UwL5f6yks6UF7SoUE6wJ2fOYnykg60y8+lXf6Rz+bkGN06Fjb4nZREnflHO65nJx7+pxhH/npOHVTKqYNKa993LSqga9Q2B5Z1ZGAZPPPVM9j5yUGOLevY7FFALSUjgcDdq83sWmA5keGji919fSbKUmPhn5K7FyBspg3twZfPGMyMO5+Km2fMMSU8Fxzob511Qp37IEb360LXosLaQACRMfx3P/FGnW20y89h/MBudcZSF0Q1Fdz5mZM444Qe
jPj6owD07dqB6UN7sOLl92mXn1vbJPKP4/sxrHcxF4wur/N5iJx51hyscixyxvpScDPaiuuPtG0P6t4RgLW3zyAvxzjx9uVAZGx9TRCIlDm3NggA9K/Xlvu3+afXXoWUFBXEPeDU98LXptcJKEfDeaOOXKFcPWUQm7bv4byRyXXbPX/rNDq2S+5Q8pNL4reZ1/wNzx3ZO26eZJw/Krngd7SVdiyktGPDINoaZOzOYndf5u5D3P1Yd1+QqXJIQ9+Kauc0IpeujZ3B5AZnjkN7dW7QvNGhMI/5M4+vk3bD9CH89v+Nr7PNV785k198fmydtLsvHVO7fP6o8joHYYAffGYkD1w9ga5FBbVNQblmzBnbr0EQAGqDAMDGO2bxwNUTjpSzoOGBrLh9PkWFR9IXXTamQZ7GFBXmUdw+P3HGeroWFdSeyQ7r0zlB7tRtXjirTlNVz+J2/PqKcQ3a3OPp1rEw5ll7U/Xt2oHNC2dx7kmpBQJpujbxqEpJn4UXDGfTB3v4zyc3xs3TtagAM3CPtJU3ZmTfLlw+cUDMu1AB8nOswUE5LzcnZrt+tEHdO9bpiIulqDCP0f0inYk1DVNN6WCL1/QRT6d2TT+op+p3X5jArv3pvYNZpD7NNRQyc8b2S5jHgNe+OZPPTejPrefEv89v4qBS/nDNqfQqblcn/TdXjuPYskgzSXNGPjxw9QTu+0JkBMdjN0zmsRvqjqQ5ZWDDIDLluO4AnDcq+VHINSWbdkL3JpfxaGlfkEv3Tu0SZxRJga4IwiiJ/vCCvBy+fm7soYbQ+C3vEwaVcum4Y5o9b07NWT5EOtqi/e/80+naoWFb+4DSoiZ3wEU6YKfG7PyMdsGoPnX6ArLFwNIiNiYYeSPhEPpAsL/6ED99In4zSTZKFAfSMXy5pcZeRY/uSYdEzU8A//6ZkWndZ2vx6D+fhm6XEQhh01DVrv387Y0j8wP9/OlN3Lny9QyW6OirP8b+9OO7x70TNpZbZ52QdF7dod165eXmxOxUl/AJ1RXB4cPOyQtWAvDkV6bQvVO7hHO0Z6P6E87OndCfyUPKosbWN370vmLigDrv84I7Z6OHENYPNueM6MXjr25rsK3Hb5zCx02Y515E0i9UgWDZuiPz/0/+7hNMGlzKyCxs+43nkS9F7iyN1xrQp0t73glmXY1l+XWnsfbtjxp0AB/XoxM3zzye82N01FoQVOKNHz8a86iISONCFQjq36r+9Ibt7EjwtKFs8furTqm95T/eIyju/cJ47lj2au2t7/Ud17NTnblkapgZXwimHxCRtif0DYTr3vk4caY2pv5dqd2KChhzzJHJweJNo1Fe0oG7Lh1d57b85nLN4irSZoQqELSFx3Kmw/yZx3PKwG61sy3eWW/Uy9H4NYwPxvqfcXzrHaMvIhGhCQTVhw7zneWvZboYR4WZ8dt542snvqo/zUH14ZZ/6M7w8uLIXPZRk3GJSOsUmkDw2KvbqNqV/OyPbcWNM4bEvZFqwXnDGT+wK8f3qtuuX30oHFdGIpKc0ASCeNMkt1aP3zglqZkKhwUPCIlui+8QtPEPLy9mybxTGkwIdlFFZHbGfkncTCUi2S80gaCtGVBaxMxhDR/tF8+ar83gr/8ylVtnnZBwvp0xx3Rl88JZHNNNgUBEQjZ8tK1JZsRNzXVOzaP6GnvKlYhILLoiEBEJOQWCti6Fro+OwUNX8nM02F8kzELTNNS2uoojWnq8/7fPH86wPsUJHxIjItlNVwStWP3O4l9fMbZBnnh3CSejpKiAa6YOatbDY0QkeygQtGITBpXWuUdg0uDkp4oWEUlWSoHAzC4ys/VmdtjMKuqtu9nMKs3sNTM7Myr9rCCt0szmp7J/OTrTRYhIdkv1imAdcAHwVHSimQ0F5gAnAmcB/2FmuWaWC9wFzASGAhcHeSVJ//nZMZkugohkmZQ6i939FYj5gPLZwBJ33w9sMrNKoKaBu9LdNwafWxLkbd7DbUPiN1eOo3MwX9Dp
9SZx0xWBiKSqpUYN9QGeiXq/JUgDeLte+rgWKkPWiJ64rX7I7VOS3mf4ikj4JAwEZrYSiDXXwS3u/lC8j8VIc2I3RcU8pzWzecA8gH79+iUqZpuWn2sUFebx0d7Ej2yMvvp67IbJDCzr2JJFE5EQSBgI3H1aM7a7Begb9b4c2Bosx0uvv99FwCKAioqKlBtAWnMTimGcOqiUR158N4m8RygIiEg6tNTw0aXAHDMrNLMBwGDgOWAVMNjMBphZAZEO5aUtVIY25fsXncRjN0xOmE9D/kUk3VLqIzCz84EfA2XAI2a2xt3PdPf1ZnYfkU7gauAadz8UfOZaYDmQCyx29/Up1SDpsh6NvTSTQbv83KTO8HXzl4ikW6qjhh4EHoyzbgGwIEb6MmBZKvvNNjq0i0gm6c7iVkAn+SKSSaEJBK25s7g5mvLQGhGRxoRm9tHW7IqJA5qU//lbp9GpXX7ijCIiSVAgaAW+cubxTcrfLYlnGYuIJCs0TUMiIhKbAoGISMgpEKTJqH5dmpR/7ICuLVQSEZGmUSBIk0vGNm0+pF98/uQWKomISNMoELSwc0b0ynQRREQapUDQwnIauVtsYFnRUSyJiEhsGj6aJn26NO25AIbxh2tO5YPdB+qkf++ikxjepzidRRMRaVRoAoHHfuxB2hTmN/3iqnO7fDrXuzHswjHl6SqSiEhS1DQkIhJyCgQiIiGnQCAiEnKhCQSWoVn/4/VMaOppEWktQhMIWrqzWESkrQpNIMiU6BP/AaW6b0BEWh8FgrSJ3dYTfR3yX3MryM9Vm5CItC4KBGnTeNPTD+eMZGBZx6x7UpqItH0pBQIz+66ZvWpmL5rZg2bWJWrdzWZWaWavmdmZUelnBWmVZjY/lf2LiEjqUr0iWAEMc/cRwOvAzQBmNhSYA5wInAX8h5nlmlkucBcwExgKXBzkbdP+d/7pmS6CiEizpRQI3P1Rd68O3j4D1MyPMBtY4u773X0TUAmMDV6V7r7R3Q8AS4K8LWZ/9SH+/tYO7lj2aovtoynzDGnYqIi0Numca+hy4N5guQ+RwFBjS5AG8Ha99HGxNmZm84B5AP36NW2u/2jX37uWR156t9mfT17DI3ysg776CESktUkYCMxsJdAzxqpb3P2hIM8tQDVwT83HYuR3Yl+BxDw0uvsiYBFARUVFsw+fT7y2rbkfbaKGRZx32kA+O/4Ydu87yPShPeqs05WBiLQWCQOBu09rbL2ZzQXOAc5wrz3f3QL0jcpWDmwNluOlZ52bZ54AwH9/fmyGSyIiEl+qo4bOAv4FONfd90atWgrMMbNCMxsADAaeA1YBg81sgJkVEOlQXppKGUREJDWp9hH8BCgEVlikreMZd/+iu683s/uAl4k0GV3j7ocAzOxaYDmQCyx29/UplqGVSK6tR01CItLapBQI3H1QI+sWAAtipC8DlqWy37ZMncUi0trozuIMydRsqCIi9SkQpI1O9UWkbVIgEBEJOQWCtFFTj4i0TQoEGaLRQyLSWigQiIiEnAJB2qizWETaJgWCDNH9BCLSWigQpI0a/UWkbcr6QLDnwKFMFyEmdRaLSGuRzucRSODScf2YOaxXposhIpIUBYI0KetYWLu84PzhGSyJiEjTZH3TUDpdPLZv3HX9unVg5fWTqVwwM6ltqbNYRFoLXRE0QZcOBY2uH9S941EqiYhI+uiKoAnS2b+rzmIRaS0UCJpAB28RyUYKBE2gZwiISDZSIGiCdFwRjD6mJLKt1DclIpIW6ixugnQcvP9rbgWbt+8lL1cxWERaBx2NjrJO7fIZXl6c6WKIiNRKKRCY2TfN7EUzW2Nmj5pZ7yDdzOxHZlYZrB8d9Zm5ZrYheM1NtQIiIpKaVK8IvuvuI9x9JPAwcFuQPhMYHLzmAXcDmFlX4HZgHDAWuN3MSlIsw9GjYUMikoVSCgTu/nHU2yKOTMo/G/iVRzwDdDGzXsCZwAp3/9DddwArgLNSKYOIiKQm5c5iM1sAXAbs
BKYGyX2At6OybQnS4qXH2u48IlcT9OvXL9VipmxUvy4U5OqKQESyT8IrAjNbaWbrYrxmA7j7Le7eF7gHuLbmYzE25Y2kN0x0X+TuFe5eUVZWllxtWtCDV5+KqWlIRLJQwisCd5+W5LZ+AzxCpA9gCxA9Q1s5sDVIn1Iv/Ykkt59xpR0bn2tIRKQtSnXU0OCot+cCrwbLS4HLgtFD44Gd7v4usByYYWYlQSfxjCCtTbhoTF9+OGdkposhIpJWqY4aWhg0E71I5KD+5SB9GbARqAR+BlwN4O4fAt8EVgWvbwRpbUJOjjF7ZMwuDRGRNiulzmJ3/3ScdAeuibNuMbA4lf2KiEj66M5iEZGQUyAQEQk5BQIRkZDT7KPN0L9bBzZ/sJf/uWIcxe3zM10cEZGUKBA0w++vmsDmD/Yy5pi2M02SiEg8CgTN0K1jId06Fma6GCIiaaE+AhGRkFMgEBEJOQUCEZGQUyAQEQk5BQIRkZBTIBARCTkFAhGRkFMgEBEJOQUCEZGQUyAQEQk5BQIRkZBTIBARCTkFgiR8dvwxmS6CiEiLUSBIwrzTBma6CCIiLSYtgcDMbjQzN7PS4L2Z2Y/MrNLMXjSz0VF555rZhuA1Nx37FxGR5kv5eQRm1heYDrwVlTwTGBy8xgF3A+PMrCtwO1ABOPC8mS119x2plkNERJonHVcEdwI3ETmw15gN/MojngG6mFkv4Exghbt/GBz8VwBnpaEMIiLSTCkFAjM7F3jH3dfWW9UHeDvq/ZYgLV56rG3PM7PVZra6qqoqlWKKiEgjEjYNmdlKoGeMVbcAXwVmxPpYjDRvJL1hovsiYBFARUVFzDwiIpK6hIHA3afFSjez4cAAYK2ZAZQDL5jZWCJn+n2jspcDW4P0KfXSn2hGuUVEJE2a3TTk7i+5e3d37+/u/Ykc5Ee7+3vAUuCyYPTQeGCnu78LLAdmmFmJmZUQuZpYnno1RESkuVIeNRTHMuBsoBLYC3wewN0/NLNvAquCfN9w9w9bqAwiIpKEtAWC4KqgZtmBa+LkWwwsTtd+RUQkNbqzWEQk5BQIRERCToFARCTkFAhEREJOgUBEJOQUCEREQk6BIAl5ubFmxhARyQ6hCgQ/nDOyyZ/5/kUn0au4fQuURkSkdQhVIJg9MuZEp4369JjyFiiJiEjrEapAkIxR/bpkuggiIkeVAkE9108fkukiiIgcVQoE9VjMRyaIiGQvBQIRkZBTIBARCbnQBYKendtluggiIq1K6ALBn748qXb5uVvO4PEbp3D/F0/hlIHdMlgqEZHMaaknlLVaJUUFtcvdO7WDTjCgtAhTH7GIhFTorggSUUAQkbBRIBARCTkFgnrcM10CEZGjK6VAYGZfN7N3zGxN8Do7at3NZlZpZq+Z2ZlR6WcFaZVmNj+V/YuISOrS0Vl8p7t/LzrBzIYCc4ATgd7ASjOrmbvhLmA6sAVYZWZL3f3lNJQjLdRHICJh01JNQ7OBJe6+3903AZXA2OBV6e4b3f0AsCTI26p0KMjNdBFERI6adASCa83sRTNbbGYlQVof4O2oPFuCtHjpDZjZPDNbbWarq6qq0lDM5P3lhslHdX8iIpmUMBCY2UozWxfjNRu4GzgWGAm8C3y/5mMxNuWNpDdMdF/k7hXuXlFWVpZUZdKlV3F7vn3+cI7r0emo7ldEJBMS9hG4+7RkNmRmPwMeDt5uAfpGrS4HtgbL8dIz6tvnD+e7j77Gyf27AnDJuH5cMq5fhkslItLyUh011Cvq7fnAumB5KTDHzArNbAAwGHgOWAUMNrMBZlZApEN5aSplSJf+pUXcdcloCvI0olZEwiXVUUPfMbORRJp3NgNfAHD39WZ2H/AyUA1c4+6HAMzsWmA5kAssdvf1KZZBRERSkFIgcPfPNrJuAbAgRvoyYFkq+xURkfRRO4iISMiFbvZRgEHdOzKqrx5SLyICIQ0EK6/XfQIiIjXUNCQiEnIKBCIiIadAICIScgoEIiIhp0Ag
IhJyCgQiIiGnQCAiEnJZHQgOH9YDiEVEEsnqQPDBngOZLoKISKuX1YEgP1cPIBYRSSSrA0FujgKBiEgiWR0I2ufrIfQiIolkdSDIy83q6omIpIWOlCIiIadAICIScgoEIiIhp0AgIhJyKQcCM/snM3vNzNab2Xei0m82s8pg3ZlR6WcFaZVmNj/V/YuISGpSelSlmU0FZgMj3H2/mXUP0ocCc4ATgd7ASjMbEnzsLmA6sAVYZWZL3f3lVMohIiLNl+ozi68CFrr7fgB33xakzwaWBOmbzKwSGBusq3T3jQBmtiTIq0AgIpIhqTYNDQEmmdmzZvakmZ0cpPcB3o7KtyVIi5fegJnNM7PVZra6qqoqxWKKiEg8Ca8IzGwl0DPGqluCz5cA44GTgfvMbCAQa24HJ3bgiTlFqLsvAhYBVFRUaBpREZEWkjAQuPu0eOvM7CrgAXd34DkzOwyUEjnT7xuVtRzYGizHSxcRkQxItWnoD8DpAEFncAGwHVgKzDGzQjMbAAwGngNWAYPNbICZFRDpUF6aYhlERCQFqXYWLwYWm9k64AAwN7g6WG9m9xHpBK4GrnH3QwBmdi2wHMgFFrv7+hTLICIiKUgpELj7AeAf46xbACyIkb4MWJbKfkVEJH10Z7GISMgpEIiIhJwCgYhIyCkQiIiEnAKBiEjIKRCIiIRcaALB9y46KdNFEBFplUITCC4cU57pIoiItEqhCQQiIhKbAoGISMgpEIiIhJwCgYhIyCkQiIiEnAKBiEjIpfo8glbvkS9N5LlNH2a6GCIirVbWB4ITexdzYu/iTBdDRKTVUtOQiEjIKRCIiIScAoGISMgpEIiIhFxKgcDM7jWzNcFrs5mtiVp3s5lVmtlrZnZmVPpZQVqlmc1PZf8iIpK6lEYNuftnapbN7PvAzmB5KDAHOBHoDaw0syFB1ruA6cAWYJWZLXX3l1Mph4iINF9aho+amQH/AJweJM0Glrj7fmCTmVUCY4N1le6+MfjckiCvAoGISIakq49gEvC+u28I3vcB3o5avyVIi5fegJnNM7PVZra6qqoqTcUUEZH6El4RmNlKoGeMVbe4+0PB8sXAb6M/FiO/EzvweKz9uvsiYFFQhiozezNRWRtRCmxP4fNtUdgyk6KCAAAEuklEQVTqHLb6guocFqnU+ZhkMiUMBO4+rbH1ZpYHXACMiUreAvSNel8ObA2W46U3VoayRHkSlHG1u1ekso22Jmx1Dlt9QXUOi6NR53Q0DU0DXnX3LVFpS4E5ZlZoZgOAwcBzwCpgsJkNMLMCIh3KS9NQBhERaaZ0dBbPoW6zEO6+3szuI9IJXA1c4+6HAMzsWmA5kAssdvf1aSiDiIg0U8qBwN0/Fyd9AbAgRvoyYFmq+22iRUd5f61B2OoctvqC6hwWLV5nc4/ZVysiIiGhKSZEREIuqwNBNk1nYWaLzWybma2LSutqZivMbEPwsyRINzP7UVDvF81sdNRn5gb5N5jZ3EzUJVlm1tfMHjezV8xsvZl9OUjP2nqbWTsze87M1gZ1/tcgfYCZPRuU/95gsAXBgIx7gzo/a2b9o7YVc5qX1sjMcs3s72b2cPA+2+u72cxeCqbnWR2kZe577e5Z+SLSGf0GMBAoANYCQzNdrhTqcxowGlgXlfYdYH6wPB/4t2D5bOBPRO7nGA88G6R3BTYGP0uC5ZJM162ROvcCRgfLnYDXgaHZXO+g7B2D5Xzg2aAu9wFzgvSfAlcFy1cDPw2W5wD3BstDg+98ITAg+F/IzXT9Gqn39cBvgIeD99le381Aab20jH2vs/mKYCzBdBbufgComc6iTXL3p4D6z9ycDfwyWP4lcF5U+q884hmgi5n1As4EVrj7h+6+A1gBnNXypW8ed3/X3V8IlncBrxC5Ez1r6x2UfXfwNj94OZHpW+4P0uvXueZ3cT9whpkZUdO8uPsmIHqal1bFzMqBWcDPg/dGFte3ERn7XmdzIEh6Oos2rIe7vwuRgybQPUhP
eYqP1iZoAhhF5Aw5q+sdNJOsAbYR+ed+A/jI3auDLNHlr61bsH4n0I22VecfADcBh4P33cju+kIkuD9qZs+b2bwgLWPf62x+ZnG8aS7CIF7d2+TvxMw6Ar8HrnP3jyMngLGzxkhrc/X2yD03I82sC/AgcEKsbMHPNl1nMzsH2Obuz5vZlJrkGFmzor5RTnX3rWbWHVhhZq82krfF65zNVwSNTXORLd4PLhEJfm4L0uPVvc39Tswsn0gQuMfdHwiSs77eAO7+EfAEkXbhLhaZzgXqlr+2bsH6YiJNiG2lzqcC55rZZiLNt6cTuULI1voC4O5bg5/biAT7sWTwe53NgSAM01ksBWpGCswFHopKvywYbTAe2Blcai4HZphZSTAiYUaQ1ioFbb//Bbzi7v8etSpr621mZcGVAGbWnsgULq8AjwMXBtnq17nmd3Eh8JhHehLjTfPSqrj7ze5e7u79ifyPPubul5Kl9QUwsyIz61SzTOT7uI5Mfq8z3Xveki8ive2vE2ljvSXT5UmxLr8F3gUOEjkTuIJI2+hfgA3Bz65BXiPyAKA3gJeAiqjtXE6kI60S+Hym65WgzhOJXOq+CKwJXmdnc72BEcDfgzqvA24L0gcSObBVAr8DCoP0dsH7ymD9wKht3RL8Ll4DZma6bknUfQpHRg1lbX2Duq0NXutrjk2Z/F7rzmIRkZDL5qYhERFJggKBiEjIKRCIiIScAoGISMgpEIiIhJwCgYhIyCkQiIiEnAKBiEjI/X874I0kv/q67AAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Train an expected-SARSA agent and plot the reward of every training episode.\n",
    "agent = ExpectedSARSAAgent(env)\n",
    "\n",
    "# Training\n",
    "episodes = 5000\n",
    "episode_rewards = [play_qlearning(env, agent, train=True)\n",
    "        for _ in range(episodes)]\n",
    "plt.plot(episode_rewards)\n",
    "\n",
    "# Testing: turn off exploration and evaluate the greedy policy\n",
    "agent.epsilon = 0.\n",
    "\n",
    "episode_rewards = [play_qlearning(env, agent) for _ in range(100)]\n",
    "print('平均回合奖励 = {} / {} = {}'.format(sum(episode_rewards),\n",
    "        len(episode_rewards), np.mean(episode_rewards)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Q 学习"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "class QLearningAgent:\n",
    "    def __init__(self, env, gamma=0.9, learning_rate=0.1, epsilon=.01):\n",
    "        self.gamma = gamma\n",
    "        self.learning_rate = learning_rate\n",
    "        self.epsilon = epsilon\n",
    "        self.action_n = env.action_space.n\n",
    "        self.q = np.zeros((env.observation_space.n, env.action_space.n))\n",
    "        \n",
    "    def decide(self, state):\n",
    "        if np.random.uniform() > self.epsilon:\n",
    "            action = self.q[state].argmax()\n",
    "        else:\n",
    "            action = np.random.randint(self.action_n)\n",
    "        return action\n",
    "    \n",
    "    def learn(self, state, action, reward, next_state, done):\n",
    "        u = reward + self.gamma * self.q[next_state].max() * (1. - done)\n",
    "        td_error = u - self.q[state, action]\n",
    "        self.q[state, action] += self.learning_rate * td_error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "平均回合奖励 = 855 / 100 = 8.55\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAD8CAYAAAB6paOMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xt8FPW9//HXJ0tuBAKBQIAESIAgBgqIkYuIonJH5bS1Hqyt1Evxbo9WLZbaWq2Vn72deo61pS3nHG091Lb2SC2KWGm1rYpYb6CiEVAQFRC5CyHh+/tjZ8MmO5sLu8kmO+/n47EPdr7znZ3vd9nMZ76XmTHnHCIiElwZqS6AiIiklgKBiEjAKRCIiAScAoGISMApEIiIBJwCgYhIwCkQiIgEnAKBiEjAKRCIiARcp1QXoDkKCwtdaWlpqoshItKhvPDCCzucc72aytchAkFpaSlr1qxJdTFERDoUM3unOfnUNSQiEnAKBCIiAadAICIScAoEIiIBp0AgIhJwCgQiIgGnQCAiEnAKBEn00b5DHKiu4ZPqWj7ad4iXNu/iqTe318tz8HAt+w/VxGz70uZdrH1vNx/tO8Sjr75P7RHHrgPVcfdVU3uk0fXVNUfY/cnhmPRHXtnqu93DL73HnoOHue+ZTSx+6m2qa47E/eyd+6vZf6iGg4drcc4x/741vP7+HvYePMzeg/X3+drWPTy+7gN27g/vc8/Bw1TXHKH2iOPj/dVx67H7wGFqav3LsO9QDdv2HGT73kPUHgk/anXbnoNUbdvLH1/eyp6DsfWG8Hf24JrNHDni6v6fouu/71ANL7yzk39U7WDzzgOsWr8t7nfgZ9uegzHf+WNrP+DND/fWlXPvwcMcqqnl4/3V1B5xvLplNy9t3hXzWZ9Uh38nke8p4onXPuSD3QfZtGM/h2pqY7aL9//+0b5DOOf4aN8h330dqK6ply/iQHVN3boI5xy//NtG399xTe0Rdh/w33+0fYdq2H+ohre37+Nvb+3g7e37eHv7Po4c8S9jtJ37q9kRJ88+73cZvd/9h+r/X0dEvqu3t+/jH2/vAMJ/n4+tfZ+XNu9i14Hqut/gR/sOsXnngZjvIvL/5GfV+m28+eHeevuO/OYjnHP8/oUtMZ8LsPK1D/lwz8F4X0NSWUd4ZnFlZaVrjxeUvb/7E3YdOExJQS6r1m/n2v99MW7e+y8Zyw8ef9P3j35I7y5UbdtXL21Aj868u/MAI4rzuWryEPYequGm371CcfdcPjummLufrIr5nNzMEF+eVEZ5UVeu8cqSnxO+ZnDPwfo/tFvPruDWP75GZsg4XOv/G/jSyaU45whlZLDk7xsBOHVor5jgFs9nxhTz0D/fq1u+bspQfvTEm755f3FhJZfeF/t//Knibrz63m4AZn+qL/265/Dzpzc2a/8PXDqOb//xNbbtPchZI/tx/7P+19Zce8YQ3++zoV5ds9m+N/YA1L9HLpt3flK3/PVZw/jN85vZe7CGbT75/fz0Cyfy3MaP2LB9P6cMKeSO5a/XW3/5aYO5eGIpY7/755htbzmrgqL8bK5+IP7vrzm+MH4Av3r23WPaNvJ7ijh7VD/+sn4bew/GHuAa5o3n3BNL6F/QOe5vBuCk0gKcg50Hqtmx91DM79xP986Z7PIJVskwqbyQp9/a0WieUSXd6J2fw4vv7qoX0KYPL2LFug9j8l8/dSjXnll+TOUxsxecc5VN5ktVIDCzGcCPgRDwC+fconh5UxEIfvH0Bn7617dZ842p9dI/3l/NCbev5N4LxnDFr//ZpmUSkWDatGj2MW3X3ECQkq4hMwsB9wAzgQrgfDOrSEVZ4vnOn15nx75q/lG1g3lLVnPEa9q/4p2dKgiISLpI1b2GxgJVzrkNAGa2FJgDNN1ebGOf/8VzAFz9v/9k8nG9KcrPSXGJRESSK1WDxcXA5qjlLV5am9u4Yz8r1n3Q
ZL7lr37ATb97pQ1KJG1hyvFFSf282+cMr7c8rE/XFn/Gc18/kw3fncXb350Vs+7N78yMu92fv3pas/ex8c7Yzwb474tOqrc8c0SfuJ+x/jsz2HjnLB668mRevGUqy6+dxFWnD+aWsxpv1H/lzPK6MSuAKyYPjslz0cTSmLSFs44H4K5zR8b97EnlhXHXXXX6YH72xRMbLRtAKMPYeOcs3rh9BvdfMpb7Lh7LP2+ZGpPv9OOO3syza1R9irvnsvHOWWxaNJsrJg+mrDCv3naTj+vFD88bVbf8Xw2+819fOq7u/ZVR383qhWc2WfZEpSoQmE9avcEKM5tvZmvMbM327c0bnGyuyd9bxf3PbALg9O//hcvuf6He+l8/F/+Gfa9t3ZPUskhyfXXqUO46dySbFs2OezD73eUT+GYTB62mlBTk8tWpQ3ng0nHce8EYcjJDdetuPbui3oHpqRtPj9n+7wvOqLdc2CWbovwcMjKMUEb9P49BhXlkdYr/pzq4VxeWXzuJN26fQbfcTL44fiCbFs2ue/XvkUtuZojR/btjZmxaNLuu/t1yM+mRl8UpQwr58dzRnDOqHwBdsut3FvxyXrib+d4LxpDdKYSZMWZAAQV5WVT0y+fG6cP4wvgBvuU7ZUghmxbN5rqpQ3nl1ulcP3UoABkGq26YzNyT+gPw7M1nMv/UQfW2Pam0gC+fOohNi2ZzXmX/uvQzh/Wmd9dsAK45YwhLvlT/oBrtxunDmD48fmCLmDGiD2ZGTmaISeW9OHVoL3rkZcXk69+jc9zPMAv/331txjBW3TC5Lr1ftxz++6KxfGZMCdeeMYTR/btz+nG96207ccjR38xNM4bVve/dtfV7IVLVNbQF6B+1XAJsjc7gnFsMLIbwYHEyd77powPc8vA6vjih1Hf9wj+sjbut35S9oPrG7OP57ZotrP9wb4u2G9izM+98dCDh/edlhdjfYFrglacPqTuQXj91KI+ujW3tmcGAnp1Zfu0kZt39NABF+dl8uOcQy6+dxNvb99XNugL4zfzx/OviZ+t9xg/PG83Ysh51y49F7edLE8uoqT3ChRNK6w4anx1TQl52iPueeYeCzpkUd89lbFkPVm/c6TsQ+OO5o/nK0pcA6Nu9/oHg2+cM51vL1nHbnOHM8A5wFf3yAXj5W9NiPuvpm86ISTvsTWE8r7KEhbPDQWHO6GJOHlzIS5t3cfnkwfzj7Y94b1d4NtSZxxfx1xsnM7BnXsxnRWR3CrFw1vHcsfx1fnLBGBY9+gbv7jzA/ZeMrZdvzIACACpLe1BWmMd3/mUEV50+hD7dcupmZXXN6cTegzUsmDms3rbH983n9ff3cP20oQzsmcfBw7UUdgkHhJXXncqTb2yjoHMW+bmZXP6r+id4qxeeyTUPvMhzG3fy28sncLj2CJ//ebjr9x8LzqBnl9iDPsAj15zCGx/s5YbfvgzAhEE9ue+Z2JPFs0b1jUn7/RUn89l7/1EXIACun3Yc1087Djg6K+7G6cfFbPvzCysp7p7rW6ZkS1UgeB4oN7My4D1gLvD5FJWlRf79ibdSXYR249JJg9i5v7rFgaDh2Wa0Z28+k/F3xk6R9JPZKQOiAkFx99x6Z9PlRV25YdpQvv+4//TDSN7y3l146MqTWbV+OxX98qnol09xQfgsuk9+DgVRZ4Ub75zFrgOH66VBeOpftE6hjHpnjj84bxTb9h7kvmfeqdvvry4ZR3WcayXmjC7mjGG9WbV+O5O8M8W/3jiZLR9/wsQhhZwzql9MGVriX0/qz9+qdvDlSfXPwHt1zeapm8ItmOXXTuKapS/yfa9LprEgEHHppDLOPbGEgrwsphxfxKGa2noHQYBTygt58ZapdeWP/q56dc3mp18Yw7iynr71e/Cy8axav53h/boB9X9L5UVdKS8Kd8kdOeI4e1Q/5k0YWLe+d9ccfnXpOA4erqVrTiYAb9w+A+cgNytEPCOKuzGiuBvnnljCx/urKcjL4nvnjmTj
jv11U5KfuvF0SgpiD9qRVktGnAbdA18ex6r12+taYtEnBVMrktt92ZiUBALnXI2ZXQ2sIDx9dIlzbl0qyiJHLfrMp1jw0Kst2sb8Ovni6JRh1ByJ37h7/LpT6dMtthm86obJ/PJvG/jVs++y7tvTGf6tFeF9N6MsF00siwkEpd4BLXJArnWOrjmZdX+McPSsNeIXF1by3q5PMDPfA5SZseG7s2is6XrEO+ZneAXN6pTRaJdPwzIN7JlXdzBOJAgAdO+cxf2XjGs0T7fOmdx38dhG8zQU/f00Vr/Gyj9jROyZdUTD7ySejAzjP84/ISY9M5RBZuhomaK79JojUu7Ped1Ub3ywlyff2EZh1ywyMmJ/gCUFuVw8sYzPj+sfsw6aX5/WlrIri51zy51zQ51zg51zd6SqHBL2pZNL686yor3wjSkJfe5/XXQS2Q0OBg0P2H/+6mms+cYUhhb5D7CGuw8+xaZFs8nL7lR35nXiwB718vX0ObjkZXeqd5a1adFsenpdCZFyRc4OGzOlooh5J5c2msevfz9apN6R/UvH95+fP4FHrjmFzln+59RmxjfPrmBI75ZPHmhLHeJRlW1l78HDdX2iQXPCgO4M6d0lJt1vsCya+Yz7jx/Ug2c37ATgtPJeRK5ZvOvckfzHk1V0zgz/7C45pYyRJd0Y3Ct2vz3zsnjyq5N9x2Qe/cokbl32Gt87dySDvr4cCF/NO2d0/IlnF04YyIEG4wn9e3Tm2+cMb9ZAYjIU5edw+5zhTK1om/1J6+uc1YkRxbEnUB2NAkGUL/xyNS/73AIiKPz6SRv27zZlbFkPls6fQOmCPwHhs+Sfz6vkv/6+kU+fUMxnxpRw3s+eAcJTOCcM7hnzGT/74omMKO5Gt86ZQOzZetecTH4QNQ0PYP6psVMRo902Z4RvelNn+ckWb4KCSCopEHh+8/y7gQ4CDc0/dVDdNL9okdkcEc2JE6cN7cVpQ4/OvY50pMfrRWmrM3QRCdPdRz1f+33LBknTXeeskO9AWkYLWwh+jnh9RX6Da8fiU2nQNBdJJbUIAmJoURfe/HBf0xk9lzaYVhjRJz+n3m2Oj+VQHgkEyQgDf7lhMoVdNfgqkgi1CAKiom8+mxbN5mlvjvjEIbF98xCe4gnEzPQB+OF5o7ivwcVBkYNw5HYB06LmPk+LMw86MsUyCY0LSgvzGr0uQUSapr+gNDCqf3ff8Y3PnVjCn9/YVvdQGAjPlHnttulkhjK467E3+J9n3qn3EJqLTylj8VMbCEUdpa8+fQjlRV18Z+VcMG4gXbI7MWd0MVdMHkxX76C87tvTfYMJwNE7nyena0hEEhPoFsGmHftTXYSk+L8rT/Z9cMX104Zyy1nHx6R3zupEZiiDhbMrYs7ab545jI13zqrXf3/D9OPiTs0MZRifGVNCKMPIz8msm2WUl92JTqE4gcD7NxktAhFJXKACwf3PbGLLx0fvcTP5+39JWVmSyczokn10YPeskX3ZtGg2fbu1/D4lZtbiKaMtFXkYUjIGnkUkcYEJBLsOVIdvNPfL1akuSqv47JiSuve9WjB42toHfT91s4YUB0TahcAEgsjDwxt74HtH1inqrlZfi7qFbXt8JHWkTH5XJYtI2wtMIEjFmW+b8qqXF2f+f3uq/3Dvlsndcpu+x4+ItD7NGupgunfOZNeBwzHp8Y7zA7zb+0YOvu3BbXNGMHfsAAb0jP+ADxFpO4FpEUS0w56SFll5nf9jCeOd71eW9uCxf5vEJaeUtV6hWignMxRzm2cRSZ3AtAjaT8dI62is62dYn/bTGhCR9idwLYKOLtld/WeNDD8EpCX36xncK8/33v8i0jEFpkWQDn48d3RM2mWn+t8TqLmmD+/j+8zcxvz5q5MT2qeItC+BaxH4DbR2FH5X93b0MQ8RSb3ABIJ2NHsyqVx7vFBARDqUwASCdNFUPGtP1wuISMcQmEAQfYfNjmT8oB6Nro80CPKyQlx+2mAevGxCG5RKRNJJ
YALB3J8/m+oiHJOzR/VrdH2Fd6GYmbFg5rC6ZRGR5kooEJjZ58xsnZkdMbPKButuNrMqM1tvZtOj0md4aVVmtiCR/bfEhu0d85bTmRn+/0UFnTNZed2pfPoE/9tDi4g0V6ItgrXAZ4CnohPNrAKYCwwHZgA/MbOQmYWAe4CZQAVwvpdX4mnQ5X/0Xv5GeVFXjQmISMISCgTOudedc+t9Vs0BljrnDjnnNgJVwFjvVeWc2+CcqwaWenkDbVBhXqqLICIB1lpjBMXA5qjlLV5avPRA0wRQEUmlJq8sNrMngD4+qxY65x6Ot5lPmsM/8PgeB81sPjAfYMCAAU0Vs8Ma0KNz3YNa/MTr+FGHkIgkS5MtAufcFOfcCJ9XvCAA4TP9/lHLJcDWRtL99rvYOVfpnKvs1atX0zXpoKZWFPkGgugnjomItKbW6hpaBsw1s2wzKwPKgdXA80C5mZWZWRbhAeVlrVSGDsE5/6eIuTgdRrqQWESSLdHpo582sy3ABOBPZrYCwDm3DngQeA14DLjKOVfrnKsBrgZWAK8DD3p5A+GL4wf6pjd2b/54s4I0WUhEkiXRWUN/cM6VOOeynXNFzrnpUevucM4Nds4d55x7NCp9uXNuqLfujkT239FMG14Uk2YGd507stmfEa+lICJyrAJzZXF75Vz4iV05mQ3+K3S8F5E2okDQhsxnrk/kDH9QYZc424iItC4FgjaU4XNU9wsOzaMQISLJoUDQlnTsFpF2KBCB4J5VVakuQlxxp4nG30BEJKkCEQi+t8LvdkhtL7tT87/uyJPH4k0T1fRREUmWQASCVJtaUcTrt80gKxRKdVFERGIoELSSU4YU1r3PMMjNCvl2A+lKYRFJNQWCVpKXffTsv7GZQU118TRcr7ghIsmmQNCGjn2qqN9niYgkhwJBisXrGtKZv4i0FQWCVhJ99p/IDJ+GrQiNKYhIsikQtIFOofpf84ji/Ca3GVXSHYABPTv7rtf0URFJFgWCVnLF5MF170PeQbt3fjYAZxzXu8ntL5pYyhPXnxZzi+qeXbLo3yOX2+eMSF5hRSTQmnxUpbTcc18/k6L8nJj0ovwc1nxjCj06Z3H3k41f7WxmDOkdeyO6zFAGT990RtLKKiKiQNAK/IJARGGX7DYsiYhI09Q11AbiPWUsWmVp/KeUiYi0JgWCNtDYQT5yT6Fbzqrg8etObasiiYjUUSBoA58fO6DJPJmhDIYWdW2D0oiI1KdA0Ab8uoZOGBCeHlpWmNfWxRERqUeBIEUq+oavJQj5PbZMRKQNKRCkiC4IE5H2IqFAYGbfM7M3zOwVM/uDmXWPWnezmVWZ2Xozmx6VPsNLqzKzBYnsvz25cMLAFuXXrSJEpL1ItEWwEhjhnBsJvAncDGBmFcBcYDgwA/iJmYXMLATcA8wEKoDzvbwd3m3HeqWvmgYikmIJBQLn3OPOuRpv8VmgxHs/B1jqnDvknNsIVAFjvVeVc26Dc64aWOrlFRGRFEnmGMHFwKPe+2Jgc9S6LV5avPRWc6C6pulMIiIB1uQtJszsCaCPz6qFzrmHvTwLgRrg15HNfPI7/AOPb2+5mc0H5gMMGND0PPx41n+w95i3FREJgiYDgXNuSmPrzWwecBZwpnN1Q6BbgP5R2UqArd77eOkN97sYWAxQWVl5zEOrmp4pItK4RGcNzQC+BpzjnDsQtWoZMNfMss2sDCgHVgPPA+VmVmZmWYQHlJclUoamtNdAoElDItJeJHr30f8EsoGV3tWzzzrnLnfOrTOzB4HXCHcZXeWcqwUws6uBFUAIWOKcW5dgGRrV3qdpts8wJSJBklAgcM4NaWTdHcAdPunLgeWJ7Lcl8nMy22pXx8QvTp04sP5N6kYU5zN5aNMPsxERORZp/zyCTqGOdc69+utn0rVB8HrkmkkpKo2IBEHaB4L2rmGY6t3IQ21ERFpD2t9rqJ0PEYiIpFzaBwIREWmcAoGISMApEIiIBJwCgYhIwCkQpEh7v9BN
RIJDgSDF9DgCEUk1BYIUU8tARFIt7QOB05FWRKRRaR8I2jt1DYlIqikQiIgEnAKBiEjA6aZzKXL1GUN4e/s+zhrZL9VFEZGAUyBIkeLuuTx42YRUF0NERF1DIiJBl/aBQLNHRUQal/aBINle/uY0rp86NNXFEBFJGgWCFurWOZPczFCqiyEikjQKBC3w8FUTU10EEZGkUyBogVH9uze6PidTX6eIdDwJHbnM7HYze8XMXjKzx82sn5duZna3mVV568dEbTPPzN7yXvMSrUB78vzCKbx4y9RUF0NEpEUSPYX9nnNupHNuNPAI8E0vfSZQ7r3mA/cCmFkP4FvAOGAs8C0zK0iwDG3O4T8VqWtOJgV5WW1cGhGRxCQUCJxze6IW86DuCDkHuM+FPQt0N7O+wHRgpXNup3PuY2AlMCORMqTSJaeUpboIIiIJS/jKYjO7A7gQ2A2c7iUXA5ujsm3x0uKlt5q/V+1I+mca5v0rItLxNdkiMLMnzGytz2sOgHNuoXOuP/Br4OrIZj4f5RpJ99vvfDNbY2Zrtm/f3rza+Hjlvd3HvK2ISBA02SJwzk1p5mc9APyJ8BjAFqB/1LoSYKuXPrlB+l/i7HcxsBigsrJS1weLiLSSRGcNlUctngO84b1fBlzozR4aD+x2zr0PrACmmVmBN0g8zUsTEZEUSXSMYJGZHQccAd4BLvfSlwOzgCrgAHARgHNup5ndDjzv5bvNObczwTI0SvcaEhFpXEKBwDn32TjpDrgqzrolwJJE9tsSuw5Ut9WuREQ6pLS/FPbRtR+kuggiIu1a2gcCERFpnAKBiEjAKRCIiAScAoGISMApEIiIBJwCwTHo3yMXgNLCvBSXREQkcQoEzZQZOnqbpOnD+/DgZRO4YNyAFJZIRCQ5Er77aBCZGWPLeqS6GCIiSRG4FsG8CQMbXf/QlSf7pptuOi0iaSpwgcCs8QN6UX5OG5VERKR9CFwgaIrO+0UkaBQImksRQkTSVOACQRM9Q02uFxFJN4ELBE3plpuZ6iKIiLSpwAWCpmb/dM7qxOyRfduoNCIiqRe4QNAcGeofEpEAUSDw4RcGFBpEJF0pEIiIBJwCgYhIwCkQNJOGDUQkXQUuEOiALiJSX1ICgZndYGbOzAq9ZTOzu82sysxeMbMxUXnnmdlb3mteMvYvIiLHLuHbUJtZf2Aq8G5U8kyg3HuNA+4FxplZD+BbQCXggBfMbJlz7uNEy5FMyWo1PPZvk6ipdcn5MBGRVpKMFsGPgJsIH9gj5gD3ubBnge5m1heYDqx0zu30Dv4rgRlJKEO7NKxPPiOKu6W6GCIijUooEJjZOcB7zrmXG6wqBjZHLW/x0uKltxnXjBN03WZCRIKkya4hM3sC6OOzaiHwdWCa32Y+aa6RdL/9zgfmAwwY0LaPhFwwcxgDe+Zx+yOvHS2PLikTkTTVZIvAOTfFOTei4QvYAJQBL5vZJqAE+KeZ9SF8pt8/6mNKgK2NpPvtd7FzrtI5V9mrV69jqdsx65zViUtOKWvTfYqIpMoxdw055151zvV2zpU650oJH+THOOc+AJYBF3qzh8YDu51z7wMrgGlmVmBmBYRbEysSr0bzHetAsKadiki6aq2H1y8HZgFVwAHgIgDn3E4zux143st3m3NuZyuVQUREmiFpgcBrFUTeO+CqOPmWAEuStV8REUlM4K4sPlbqGRKRdKVA0Ey6LExE0pUCgYhIwCkQNJO6hkQkXSkQiIgEnAJBgkaW6F5CItKxtdZ1BGnH4lxR9vBVE9u4JCIiyaVAkKB4AUJEpKNQ1xCQk6mvQUSCK3BHQL/z9xdv8buBatPbiYikg8AFAl0YJiJSX+ACgR+n8CAiARaoQDBhUM9UF0FEpN0JVCAQEZFYCgTNpdFiEUlTCgQiIgEXuEDgd2LvNFYsIgEWuEBwrNQzJCLpSoGgmdRoEJF0FahAYOZ/QNdBXkSCLFCBIBHqGhKRdKVAICIScAkFAjO71czeM7OXvNes
qHU3m1mVma03s+lR6TO8tCozW5DI/lvKOf8z+86ZId/8c0b349azK1q3UCIiKZaMFsGPnHOjvddyADOrAOYCw4EZwE/MLGRmIeAeYCZQAZzv5U2pjAz/jp8fzz2BfzmhGNBzB0QkfbXWg2nmAEudc4eAjWZWBYz11lU55zYAmNlSL+9rrVSOeo7lWK5rDEQk3SWjRXC1mb1iZkvMrMBLKwY2R+XZ4qXFS2/31CAQkXTVZCAwsyfMbK3Paw5wLzAYGA28D/wgspnPR7lG0v32O9/M1pjZmu3btzerMiIi0nJNdg0556Y054PM7OfAI97iFqB/1OoSYKv3Pl56w/0uBhYDVFZWJqWDxgxKCnJbtE1mp3CsPL5PfjKKICLS7iQ6a6hv1OKngbXe+2XAXDPLNrMyoBxYDTwPlJtZmZllER5QXpZIGRrjfDr4L5xQWm/51KG9Gv2MLtmdePCyCfzswhOTWTQRkXYj0cHiu8xsNOHunU3AZQDOuXVm9iDhQeAa4CrnXC2AmV0NrABCwBLn3LoEy9BszsXOEGpO1//Ysh6tUyARkXYgoUDgnPtiI+vuAO7wSV8OLE9kv83Vkhk/w/vls27rntYrjIhIO9Va00c7nN9fcTLVtUdSXQwRkTaX1oFgf3VNk3kijYaczBA5ca4wFhFJZ2l9r6FPDtfWW9a1ACIisdI6EISs4cBwbCRQbBCRoEvrQJDRIBA4PXlARCRGoAKBiIjESutAYA1qpxvIiYjESutAkJnRdPW6d85sg5KIiLRfaT19NDer/nTQSE/R3eefwFNvbudTxd349JgOcfNTEZFWk9aBoKHIrKFzRvXjnFH9UlwaEZH2Ia27hkREpGmBCgSaPioiEitQgUBERGIFKhD4XVksIhJ0gQoE6hoSEYkVqEAgIiKxAhUI1DUkIhIrUIFARERiKRCIiAScAoGISMApEIiIBFzaB4JuuUfvLtrwJnQiIpKEQGBm15jZejNbZ2Z3RaXfbGZV3rrpUekzvLQqM1uQ6P6bsvwrkygrzANgRL9urb07EZEOJ6FAYGanA3OAkc654cD3vfQKYC4wHJgB/MTMQmYWAu4BZgIVwPle3lZT3D2Xs0b29crbmnsSEemYEr0N9RXAIufcIQDn3DZi58xHAAAGH0lEQVQvfQ6w1EvfaGZVwFhvXZVzbgOAmS318r6WYDlEROQYJdo1NBSYZGbPmdlfzewkL70Y2ByVb4uXFi+9TehRlSIisZpsEZjZE0Afn1ULve0LgPHAScCDZjYIfC/hdfgHHt/Ds5nNB+YDDBgwoKliNko9QiIi8TUZCJxzU+KtM7MrgIeccw5YbWZHgELCZ/r9o7KWAFu99/HSG+53MbAYoLKyMqFzeTUERETiS7Rr6P+AMwDMbCiQBewAlgFzzSzbzMqAcmA18DxQbmZlZpZFeEB5WYJlaDYNFouIxEp0sHgJsMTM1gLVwDyvdbDOzB4kPAhcA1zlnKsFMLOrgRVACFjinFuXYBlERCQBCQUC51w18IU46+4A7vBJXw4sT2S/IiKSPGl/ZbGIiDQuEIEgMxSuZqeQBglERBpKdIygQ/jypEHsP1TDxRPLUl0UEZF2JxCBIDcrxM2zjk91MURE2qVAdA2JiEh8CgQiIgGnQCAiEnAKBCIiAadAICIScAoEIiIBp0AgIhJwCgQiIgFnrgM8tsvMtgPvJPARhYRvjx0kQatz0OoLqnNQJFLngc65Xk1l6hCBIFFmtsY5V5nqcrSloNU5aPUF1Tko2qLO6hoSEQk4BQIRkYALSiBYnOoCpEDQ6hy0+oLqHBStXudAjBGIiEh8QWkRiIhIHGkdCMxshpmtN7MqM1uQ6vIkwsyWmNk2M1sbldbDzFaa2VvevwVeupnZ3V69XzGzMVHbzPPyv2Vm81JRl+Yys/5mtsrMXjezdWb2FS89bettZjlmttrMXvbq/G0vvczMnvPK/xszy/LSs73lKm99adRn3eylrzez6ampUfOYWcjM
XjSzR7zldK/vJjN71cxeMrM1XlrqftfOubR8ASHgbWAQkAW8DFSkulwJ1OdUYAywNirtLmCB934B8P+897OARwEDxgPPeek9gA3evwXe+4JU162ROvcFxnjvuwJvAhXpXG+v7F2895nAc15dHgTmeuk/Ba7w3l8J/NR7Pxf4jfe+wvvNZwNl3t9CKNX1a6Te1wMPAI94y+le301AYYO0lP2u07lFMBaocs5tcM5VA0uBOSku0zFzzj0F7GyQPAf4H+/9/wD/EpV+nwt7FuhuZn2B6cBK59xO59zHwEpgRuuX/tg45953zv3Te78XeB0oJo3r7ZV9n7eY6b0ccAbwOy+9YZ0j38XvgDPNzLz0pc65Q865jUAV4b+JdsfMSoDZwC+8ZSON69uIlP2u0zkQFAObo5a3eGnppMg59z6ED5pAby89Xt077HfidQGcQPgMOa3r7XWTvARsI/zH/TawyzlX42WJLn9d3bz1u4GedKw6/ztwE3DEW+5JetcXwsH9cTN7wczme2kp+12n8zOLzSctKFOk4tW9Q34nZtYF+D3wb865PeETQP+sPmkdrt7OuVpgtJl1B/4A+D1wO1L+Dl1nMzsL2Oace8HMJkeSfbKmRX2jTHTObTWz3sBKM3ujkbytXud0bhFsAfpHLZcAW1NUltbyoddExPt3m5cer+4d7jsxs0zCQeDXzrmHvOS0rzeAc24X8BfC/cLdzSxy4hZd/rq6eeu7Ee5C7Ch1ngicY2abCHffnkG4hZCu9QXAObfV+3cb4WA/lhT+rtM5EDwPlHuzD7IIDywtS3GZkm0ZEJkpMA94OCr9Qm+2wXhgt9fUXAFMM7MCb0bCNC+tXfL6fn8JvO6c+2HUqrStt5n18loCmFkuMIXw2Mgq4FwvW8M6R76Lc4EnXXgkcRkw15tlUwaUA6vbphbN55y72TlX4pwrJfw3+qRz7gLStL4AZpZnZl0j7wn/HteSyt91qkfPW/NFeLT9TcJ9rAtTXZ4E6/K/wPvAYcJnApcQ7hv9M/CW928PL68B93j1fhWojPqciwkPpFUBF6W6Xk3U+RTCTd1XgJe816x0rjcwEnjRq/Na4Jte+iDCB7Yq4LdAtpee4y1XeesHRX3WQu+7WA/MTHXdmlH3yRydNZS29fXq9rL3Whc5NqXyd60ri0VEAi6du4ZERKQZFAhERAJOgUBEJOAUCEREAk6BQEQk4BQIREQCToFARCTgFAhERALu/wONqv3FtqvHMAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Train a Q-learning agent and plot the reward of every training episode.\n",
    "agent = QLearningAgent(env)\n",
    "\n",
    "# Training\n",
    "episodes = 5000\n",
    "episode_rewards = [play_qlearning(env, agent, train=True)\n",
    "        for _ in range(episodes)]\n",
    "plt.plot(episode_rewards)\n",
    "\n",
    "# Testing: turn off exploration and evaluate the greedy policy\n",
    "agent.epsilon = 0.\n",
    "\n",
    "episode_rewards = [play_qlearning(env, agent) for _ in range(100)]\n",
    "print('平均回合奖励 = {} / {} = {}'.format(sum(episode_rewards),\n",
    "        len(episode_rewards), np.mean(episode_rewards)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 双重 Q 学习"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DoubleQLearningAgent:\n",
    "    def __init__(self, env, gamma=0.9, learning_rate=0.15, epsilon=.01):\n",
    "        self.gamma = gamma\n",
    "        self.learning_rate = learning_rate\n",
    "        self.epsilon = epsilon\n",
    "        self.action_n = env.action_space.n\n",
    "        self.q0 = np.zeros((env.observation_space.n, env.action_space.n))\n",
    "        self.q1 = np.zeros((env.observation_space.n, env.action_space.n))\n",
    "        \n",
    "    def decide(self, state):\n",
    "        if np.random.uniform() > self.epsilon:\n",
    "            action = (self.q0 + self.q1)[state].argmax()\n",
    "        else:\n",
    "            action = np.random.randint(self.action_n)\n",
    "        return action\n",
    "    \n",
    "    def learn(self, state, action, reward, next_state, done):\n",
    "        if np.random.randint(2):\n",
    "            self.q0, self.q1 = self.q1, self.q0\n",
    "        a = self.q0[next_state].argmax()\n",
    "        u = reward + self.gamma * self.q1[next_state, a] * (1. - done)\n",
    "        td_error = u - self.q0[state, action]\n",
    "        self.q0[state, action] += self.learning_rate * td_error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "平均回合奖励 = 882 / 100 = 8.82\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAD8CAYAAAB6paOMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xl8VfWd//HXJzuQAAk7ISFhJyICBnBl312otla0Vaptca9V21FLW1sdrWN3Z+yCHX6/6bTW4tSFsbhgXdsZZVFUUJCIKJuigIBASCDf+eOeJDfJ3ZKb5Obe834+Hnnk3u/5nnO+35ub8znn+/2e7zHnHCIi4l9piS6AiIgklgKBiIjPKRCIiPicAoGIiM8pEIiI+JwCgYiIzykQiIj4nAKBiIjPKRCIiPhcRqILEIuePXu6kpKSRBdDRCSprF279hPnXK9o+ZIiEJSUlLBmzZpEF0NEJKmY2fux5FPTkIiIzykQiIj4nAKBiIjPKRCIiPicAoGIiM8pEIiI+JwCgYiIz/k+EBypOs6RquPNWqf6eA0HKqsBqKlxfHq4KmS+fYeq2PPZUR55bTuHq45RWX2c3QcrqTpWU5fnQGU11cdreGnzxzy3aXfI7Rw6Glg3eJ3gbew7VEVNjWPZ6m0N0g9WVnP0WKB+h44ei1qvvYeq6rbnnMM5x+6DlXx88GjI9WtqHPu8dY57+z92vKZJvkNHj1Fyy19556OD7D9Szc5Pj/Dsxo+abGvFm7t4asOHhHt8atWxGj7YczjkstqyPPraDj5rVNYdnx7hL2u3A1BZfZzDVceoOlbDijd38f6eQzy5fhe7D1bW1bv2cwi2+2Al+49UU1l9nH2HqjhQWc22vYc5eqz+77L/cHXI+gdrvO2qY/XfpWj+vvkT3vvkUMhlz7z1EW/tPADAns+OxrS9UOWprK7/rhyorGb3wUpqahw7Pj3Co6/t4Njxmrq/+aYPD7Lqvb1Ntnm46hhv7TzAn1Z9wLLV2/jtC+/WLTtQWc1j63bUvXeu/jsUTsXugw0+6yNVx/lwfyWHjh7jYGXgM9+29zDVQZ/93kNVVFYfZ9vew2H/pqFs/PAAj63bwcOvbm/wPTx67Divb/s07Hezsvo4D63ZVnc8CbW/vYeq2PnpEf729kcN0hpvc89nR3l240fs/PRITGVuDZYMzywuLy93bXVD2eDvrOB4jWPr3WfFvM5Vf1jLE+s/ZOvdZ/HjpzZy33PvcsvcEVw5eTCV1cf54ysfMKp/Vy5c8nLYbWy9+yxuffhN/rTqgwbpI/rm0bdbDoXdO5Gdkc5FE4qY+fMX6ZmbxSefVbF43kjuXPF2+Pr06kJpz1zmjurLTQ+9DkB+50z2Ha5mYI/OvL/nMJnpRvVxx0UTirhm6hD2H6nmrHv/3qQcGz88GPNnEsq44u70zsvhyQ0fRsz3xfIBLFuzvUl6XnYGB2MIYOEM75PHpo9C16H2MwmntGcXThvcg617DvHh/kre/Tj0ARjgc2P6U9Alm6X/eC9snh5dstjjHRzOH1vIF8cXsSDo+3HbOWUsPLWEQd9ZAcCwPrm889FnfGP6UO5/cQtHqpt3shKssHsnDhyp5uGrT+P17fv5lve9CHb99KGsfX8ff6/4BIBpI3rz7MbQJyYAnx83gL+8Gvib/ebL4+iVl8OBI9Xc/vhbYYNVY/PH9OexdTtDLivr15W3dh2IaTuxmlhawCte4Oqdl82sE/rwh5c/iLJWU3eddyLfeeRNAK6YPIjfvrAlZD6zQN7xJQX89Y1d/PyZdxosHz2gG29s3x9xX4vnjeSs0f3o371Ts8sZKIOtdc6VR82XqEBgZnOAXwLpwO+cc3eHy9uageCdjw4y+xcv8uK3p1JU0JmSW/4K0CQQVOw+yNr393Hh+GL+/e/vccfjb/H7yycwaVivunWe+9YULr7/ZXbtr6zbxln3vsSGndG/wOUD
81nz/r5WqZOIpLbmnKgGizUQJKRpyMzSgfuAuUAZcJGZlbXHvpet3oZz8OT6yGepM3/+Ijf/JRD173j8LQAuXbqKDTvrI/jUnzxfFwQAPj54NKYgACgIiEiHkag+gglAhXNui3OuCngQmJ+gsgBwzr82bBoJd6HUuAml4Todv5lNRKSxRAWCQmBb0PvtXlrCvLkjcludiMRm8rD6yS7PPal/yDxDe+fy9A2Tom7r/kvL+e0lJ7N43kgg0CZ/+emlrVPQZthy1zze/MGssMvf+ee5vH37HP75c6N48wezIuZt7DdfHtcaRYxLogKBhUhrcDptZovMbI2Zrfn4449bvQAPv7YjeqZm+t8te1p9m9K+xhV3T3QRmu26aUN470fzwi7PSm/6b55m8P8uG9+g7bmweyeumjI45v1+e/bwJmlb7z6L/7h8Qt37GWV9Qq678sbJDOuTF3Ufhd07MfuEvnx90qDAAIu5I/nKaSUR15lQUtDg/aJJgyLmv+SUgTzw9YkR86SlGXk5mWGXZ2Wk0SkrnS+fMpC8nEzycjK5OsxnObBH5wbv54zqx/BGn8VPLzgpYnlaW6Kmod4OFAW9HwA0GD7gnFsCLIFAZ3FrF+DtXQd4f0/00Q3NGcJ1w5/XxVMkSYBvTB/KvX/bXPd+6vDevPrBp62y7UE9u7AlxAiaa6cO4Vuzh9cNOgBYdsWpfPG3/9vsfcTSiVheks//vNvwJOX66cOYOrx3g7QeuVlhm0RDufz0Uj47eoxlq7ex51AVpw7qUbfsH7dMIzOt4fne+JJ8Vm8N3Tf26y+N46o/vtokvUduVpO04h6defy6M8jvksXpdz/LD84p48XNn3DywHwmlhbwwjsfs2pr/bDW88YWkpedwU9XNhy18/y3plDSs0vd+2dvmkxGWhqTfvwcAP/51QkcrDxGXk79YfLKyYP5zQvv0r9bDju9/sEXvz01ZJ2KCjrXfU7jBnZn2ojevPPRZ4wp6s4zb33E135fPwDmdwvL+e83dnLPk5sAKOnZOeQ220qiAsFqYKiZlQI7gAXAxe1diOowY753H6zvAF64dFXM26tRF0HC/eLCMXwzQkB+9655DPaGaALcMKNhILhm6pAmB4xYTR/Rm79t3E2nzPSIwz1vmDms7nV2Rhov3TyVHl2yQ+ZdNGkQS17cQtecDA5UBobS/uGrEzljaM8meX/8hdE8um4H/6hoeNA/e3R/Hvj6KVTsPsj5v/ofDlQeIzen6b/+/ZeWs25bIAgunjeSu554uy4wPHvTZC5c8jLZGWls3xc4OeqUlc7Nc0Zw8YRilq3ZxjVTh9RtqzBouOPZ3vDHb84YStn3n2qwzysnD2bbvsOMKuzWIP0H55QxZXhv+nTNCfm51OavDYRfCWouGlucT7dOmYwq7EZh904UFXRmZL+unDumP/27d+LD/ZX0yssmJzO9wTYH9coF4LXvzaTGOXrkNv2bXDttCI++toOfXziG3l1zKC7oTHpaqAYOuODkAew7XMXlp5fW7WtMUeCKc0ZZH7Iz0jjq3fdTVNCZq6cMqQsE2RmB/HnZGfxuYdRBP3FLSCBwzh0zs2uBpwgMH13qnNuQiLIEO3T0GF2yM7hpWf04631hbhaTjmlEv6bNDaU9u9SNbQ/+n73n86Mxq08Y3iePtDRj691nUXWshmHffSLsfs4fW8iK9buorA78I984cxjfmD4UgKc3fMii/1xb19ZZ2L0TK75xJifd/jRA3YHj9dtmkZludM4K/BvW3ucR7Na5I/jqGaVc96fX6m7eqglz2n5BeRFnje7X4GB7+emlXDQhcPE9pHcer35vJn9avY0F4+svyGs/nz5dc5h9Ql9e+c50+nTN4euTBtVdtQzqlcvqxTMA+O/Xd/LE+l116xcVdOamWU2biWr928X1beD/fe0ZHK6qvzfklrkjAPjoQP3J1xlDerLwtJIGf5vmSE8zvnZm0+aggT261JU3kvwuTa9CauVmZ/Dyd6bH
VI6M9DSunjIk7PKVN0xm44dNRxkWdu/EqMJu3PP50cw5sS9dIzRJtZaEPaHMObcCWBE1YzuqOlZDRvrxuptqJLl87YxS+oY4g8zOqG8jDz64pDU6kws+7mSm17954GsTufh3r9S9v3HmMK6bNoQVQQfD4LPC2n3UjiLLzkyjW+em/8zdOjVMe+HbU/ntC+/yoyc2NthWn645DXrQIl141gaVWtdPH9qgzhnpaVxyysAGeZ765qQGwSXcWXitc07qzzlhOoGjOXFAt5DpfbrmcP+l5UwoKQj5WaWi4h6dKW7UX7DxjjmkeX+vL44vCrVam/D9FBPB3tixn+HffbJBO+knn+mKIFFqz7Bj9d2zy+jeuenZXEZ6wwP+xROLKevXlbNH9wu7LTNj+ohAG/qoRgev88YWYmaUD6zvlMwICgThhhHfe9FYfhKlE7C2ebFnbha/XDAmKN2FfB1Kl6xAs8K1U4fEdFDNykhr0kwSrPaKoq3NLOvjmyAQTk5mOlkZ7X9YTopnFreXx9pgJJFEd97YQh55bQeDenVhS9BUDjfOHNag/T6cUYVdycsOfwBJb9TEcNd5J4bMl9eo3fxfLx7Lu7sPNbk0z80O5PvNJSfz9q4DPLn+QxZGGMlSu/dwQymDzSzrzb88uZH/f9mEBu3mJwfdiR7tfpWffvEkrvzDqwzrG31UTjQtvaNVkosCQZC2GFIq0V09ZTCPvLaj7pI42D9/bhTffXR9xPUfv+7MiMsbNwE1Vjv3UuOz9c5ZGU2aMp65cXJdG3JudgbjSwoY32i4Ym1TzID8zuz5rIqb54yIuP9gQ3rnhTz4fnv2cPYequKhtdsZkB+5jXvOqH48cf2ZjOzXNeb9ir/5vGmoZZ1R0rqKe3RmWJ9cbj/3BHo06qj7cqP27JZYeGoJ0HAkSzDzvgehRok0NqR3bsz7zc5I480fzmbWCX1jXiecjPQ07vnCaF76p6kxjb9XEJDm8Hkg0HjPjiA7I52nb5jMaUN68qWJxRHzXjShuMkNOZHzF9Xl7xliTHqw1jotaKvTCzOLOuJFpCV80zS08cMD/PKZzTwRZbI5SaxIofn+S8s5c2hPdu2v5FfPVfDQ2qZTVzdl9aN4wmeJ6slvntmgQ1gklfjmimDOL15qEgQeflV9AvG67Zz6SWMnBc0x01K1Y71DmVnWh5zMdEp7duHHMd+C7+qO89Humo00bH1E364M6d28zlddb0qy8E0gCOVXz78bPZNEdNGE+qacr5xW356/Ksabbu6/tOFdk7VDNuPx4KJTOGVQfQdu7QHetdOhuYX3QYkkjK8DgbSu2jPuqcN70TvKTUkAg3p1YcbIhgf+1jiInjKoR9ANT1bXGRz1iqCVWvc1G7kkGwUCiUukA/fr3284Fe/fbprc4P2zN01p8TQCsXP1VwRxNA2JpDLfdBZL22t8IG18l+jgXtGHXjbnrHzpV8p5aXP06UBqn/d6QfmAmLcdDwUUSTYKBBLWheVF/HnNtoh5Yj1w3/G5Ua1RpAamjejDtBGh57uvZxR0yWLLXfN0gBYJQ01DEvIBIwADW3FO9MYTnXWKMLcNQNcQ0yS3TKA9KC3N2qEZqtGe1VkgSUKBQJrMsdMcwcfWsUX5ZKYbV06O/pSrtd+b0eJ9dnS68pBko0AgDRp3ll1xaovXze+SxeY75zEx6ElV4dcLfbRs/SGe0Y/KrX3c7uJNBR1tOmeRjkJ9BNLgFHZCaQFnndiPv765K8IKrbrLMMvrM9wwYxidstrunKW1Q8+E0gJ+9sWTmDsq/DTXIh2JAoFw/thCvhc0w+eAgsAom9YaVx+v62c077kEEJjaetV7e7kx6LGQ7cXMOH9c+4xQEmkNahoSOmdF7rgF6O4NBf1mo4Ny7dTRzZlqGdq+Hb1zVga/XDCWXnnRZxStfeh6qGmwRfxAVwTSVFBbyZJLTiYrI41HX9vBo+t2NnmSVe0zfpPZr788ju37jiTkyVAiHYECgc/NGNkn
4rDK2rn0Tx3cg5tmDee5TbtbZb8dpdkJAlcPsczxL5KqdArkU3efH3hcY+MHwYSTnZFOUUFn8oOeCbzkkpNbvH+1woh0HAoEPtXSkTLBD3yP58lbigMiHYcCgU/V3vTa3DPz1ro7N9x2dDOuSPuLKxCY2QVmtsHMasysvNGyW82swsw2mdnsoPQ5XlqFmd0Sz/6l5Wpv3Ap1PK49Frdl8020TavpSKT9xHtFsB44H3gxONHMyoAFwAnAHOBXZpZuZunAfcBcoAy4yMsr7SgjzYLOvBNzxNWBXqTjiCsQOOfeds5tCrFoPvCgc+6oc+49oAKY4P1UOOe2OOeqgAe9vNKO+narn/qgvQ/Ixd7D19t7AjgRCa+tho8WAi8Hvd/upQFsa5Q+MdQGzGwRsAiguLg4VBZppkE9u7Dlk0OYxdZZ3BaH6oeuPJU3tu9vgy2LSEtFDQRm9gwQanjIYufcY+FWC5HmCH0FEvKY5JxbAiwBKC8vVxdiGwn+Q33tjNI231+frjnMLAs/GVumd1PXiYXd2rwsIhIQNRA451oyX/B2oCjo/QBgp/c6XLq0p0bDc4LvDh7SO/AksYE9urRrkQByszN4+OrTdIOXSDtqq6ah5cADZvYzoD8wFFhF4AR0qJmVAjsIdChf3EZlkBiEaqq/4OQBDO+Tx0lF3ePe/qrF06k+3rwLunHF+XHvV0RiF1cgMLPzgH8FegF/NbN1zrnZzrkNZrYMeAs4BlzjnDvurXMt8BSQDix1zm2IqwbSIpEOzWbWKkEAoHee5uQX6ejiCgTOuUeAR8IsuxO4M0T6CmBFPPuV1tOR5vwRkcTQncU+ZJju4BWROgoEKeQnF5wUc970tMCVgKZeFhFNQ51Cpg7vFXPeC8oH8P6eQ3xjevOf/iUiqUWngymkR27kp3FNHFQABEYKZWeks/isMvJyMtujaCLSgSkQ+MRd553IokmDE10MEemAFAh8ojC/k8YHiUhICgQpbmJpQd1rDRQSkVAUCFJM45FDoUYF6cpARIJp1FCKGVvcOncER3JSUXdOLOza5vsRkfahQJBi2uNs/7FrTm+HvYhIe1HTkI843U4sIiEoECSp9T+cHTJdh3oRaS4FgiTUOSud3OzYWvWyvc7iNPUQi0gYCgRJ6IwhPQE4f1xhk2WNj/c/On80V00ZzOmDe9bn0fOCRSSIOouT0L0XjQWge6esJssaNw31ysvm5jkj2qFUIpKsdEWQhHIy0wH49uzh3HZOWYJLIyLJToEgiXXKSuey02N/4Lw6kkUkFAWCJLPyhkkRl0dq/a9dlqGeYxEJoj6CJNI1J4OhffIi5ol01l/aswtXTRnMgvFFrVswEUlqCgQ+YmbqOBaRJtQ0lEQ07FNE2oICQRLRFBEi0hbiCgRm9mMz22hmb5jZI2bWPWjZrWZWYWabzGx2UPocL63CzG6JZ/8iIhK/eK8IVgKjnHOjgXeAWwHMrAxYAJwAzAF+ZWbpZpYO3AfMBcqAi7y8EgM1DYlIW4grEDjnnnbOHfPevgwM8F7PBx50zh11zr0HVAATvJ8K59wW51wV8KCXV0REEqQ1+wguB57wXhcC24KWbffSwqW3qV37j7T1LkREklbU4aNm9gzQN8Sixc65x7w8i4FjwB9rVwuR3xE68ITsATWzRcAigOLi4mjFjOj5TR/HtX4yKejcdP4hEZFIogYC59yMSMvNbCFwNjDd1Q9r2Q4E37U0ANjpvQ6X3ni/S4AlAOXl5XENl9n80WfxrJ5U8rtk8fr3Z3HS7U+Tk6lBYSISXVw3lJnZHOBmYLJz7nDQouXAA2b2M6A/MBRYReBKYaiZlQI7CHQoXxxPGWKx9B/vtfUuOpRunTP5y1Wn0r97p0QXRUSSQLx3Fv8bkA2s9Ea0vOycu9I5t8HMlgFvEWgyusY5dxzAzK4FngLSgaXOuQ1xlsE3mjNo6OSBBW1XEBFJKXEFAufckAjL7gTuDJG+AlgRz379YsklJ3PtA69R
dbwGAN1PJiJtQY3IHVhmRhrv3Dk30cUQkRSnQJBEdD+ZiLQFBYIOLF9DQUWkHaR8IPj75k8SXYQW+f3lExhT1L1BWk5GeoJKIyKpLOWfR7DvcFWii9Aik4b1apL2wNcnhsw7uFcXyjVKSERaKOUDQUd2xeRB/PaFLTHnH9QrN2T6326a0kolEhE/SvmmoY5sVlmomTtERNpXygeCjjzSpiOXTUT8I+UDwcOv7kh0EUREOrSUDwTPbtyd6CKIiHRoKR8IREQkMgUCERGfUyAQEfE5BQIREZ9TIEggjR4VkY5AgSCB9HgBEekIFAhERHxOgSCB1DQkIh2BJp3rQEp6dGbrnsNN0heeOpAxxd1DrCEiEj9dEXQgl55aEjL9h/NHcd7YAe1bGBHxDQWCDuTyM0oTXQQR8SEFAhERn1MgEBHxubgCgZndYWZvmNk6M3vazPp76WZm95pZhbd8XNA6C81ss/ezMN4KiIhIfOK9Ivixc260c24M8DjwfS99LjDU+1kE/BrAzAqA24CJwATgNjPLj7MMScv0ZBoR6QDiCgTOuQNBb7tQf7PsfOD3LuBloLuZ9QNmAyudc3udc/uAlcCceMogIiLxifs+AjO7E7gU2A9M9ZILgW1B2bZ7aeHSQ213EYGrCYqLi+MtZofknCaZEJHEi3pFYGbPmNn6ED/zAZxzi51zRcAfgWtrVwuxKRchvWmic0ucc+XOufJevXrFVhsREWm2qFcEzrkZMW7rAeCvBPoAtgNFQcsGADu99CmN0p+PcfspR30EItIRxDtqaGjQ23OBjd7r5cCl3uihU4D9zrldwFPALDPL9zqJZ3lpvvXQlacmuggi4nPx9hHcbWbDgRrgfeBKL30FMA+oAA4DlwE45/aa2R3Aai/f7c65vXGWIamNLylIdBFExOfiCgTOuc+HSXfANWGWLQWWxrNfERFpPbqzWETE5xQIRER8ToGgFZX169qs/KHGDGVlpJGbrcdEiEj7USCI0Q/PPSHi8jkn9OWXC8bEvZ83fzCLV783M+7tiIjESqeeMUpPizzmPyczjbQoeWKRnZEe9zZERJpDVwQximUyiH7dcqLmuX56/a0XJT27AIEgc+vcES0tmohIXBQIYhRuXqChvXPrXnfOymDr3Wdx3bQhIfM+fcOkBoGgW6dMAN69ax5XTB7ciqUVEYmdAkGMws0PV5sc63QRmmZORDoa9RHEqCZMJGjuDKJpBmeP7seC8ak5o6qIJB8FghjVNON4H+nawMz4t4vHRcghItK+1DQUo6z0yE0/eraAiCQrBYIYXaimHBFJUQoEMcrKiPxR6dkCIpKsFAjiFKpBKFwjUY5uFhORDkiBoA1dNWUwv/5SfcdwcY/OCSyNiEhoCgTxCnH6X9tIlJORztwT+7VrcUREmkuBIE61Z/mjCrsluCQiIi2j+wiC9O+Ww879lc1a5+TifG6dO5JhfXKjZxYR6YB0RRBkaJ+8Fq03vG+eRg2JSNJSIAgS7Zaw/t1yyG40jDTSOk4zC4lIElDTUJBodwe/dPM0AAZ/Z0XkDenqQESSiAJBM4R6OI1mlhCRZNcqTUNm9i0zc2bW03tvZnavmVWY2RtmNi4o70Iz2+z9LGyN/beWaSN6t86GFB1EJInEfUVgZkXATOCDoOS5wFDvZyLwa2CimRUAtwHlBJrX15rZcufcvnjL0RqK8lv3hi+LOA+piEjH0BpXBD8H/omG/abzgd+7gJeB7mbWD5gNrHTO7fUO/iuBOa1QBhERaaG4rgjM7Fxgh3Pu9UbDJwuBbUHvt3tp4dIT7opJgzhjaM8m6QVdsijr1zXsepFGBmnUkIgkg6iBwMyeAfqGWLQY+A4wK9RqIdJchPRQ+10ELAIoLm77KaBvnTcSgJzMNCqra+rSX/3ezOZvTKOGRCSJRA0EzrkZodLN7ESgFKi9GhgAvGpmEwic6RcFZR8A7PTSpzRKfz7MfpcASwDKy8s77Kl1TqZmFBWR5NbipiHn
3JtA3TAbM9sKlDvnPjGz5cC1ZvYggc7i/c65XWb2FHCXmeV7q80Cbm1x6RPk4atPY9vew3yw5zCXnV6S6OKIiMSlre4jWAHMAyqAw8BlAM65vWZ2B7Day3e7c25vG5WhRWIZ+TmuOJ9xxfnRM4qIJIFWCwTOuZKg1w64Jky+pcDS1tpvR9S3aw4AfbzfoG4DEem4fH9n8RPXn9mgnb/2gP34dWeQEeWB9eEsGF9Ej9wsZo7sA8AjV5/WICiIiHQkvg8EI8MMDR3cK5dOWS3rCE5LM2afUD/QaqyakUSkA/P17KOhppTQ7BAi4je+DgS/CnqecGNq0xcRv/B1INA9ACIiPg8EoVx+RikAGSGmnBYRSUW+7yxu7OY5I7h5zohEF0NEpN3oikBExOd8GwgumtD2E9mJiCQD3waCH51/YqKLICLSIfg2EIiISICvAsHFE9UcJCLSmK8CQc/c7EQXQUSkw/FVIBARkaYUCEREfE6BQETE5/wVCDS1qIhIE76bYuLZmyZjmlpURKSOvwKBGYN65Sa6FCIiHYqvmoZ0HSAi0pSvAoGIiDSlQCAi4nNxBQIz+4GZ7TCzdd7PvKBlt5pZhZltMrPZQelzvLQKM7slnv2LiEj8WqOz+OfOuZ8EJ5hZGbAAOAHoDzxjZsO8xfcBM4HtwGozW+6ce6sVyhGVBo+KiDTVVqOG5gMPOueOAu+ZWQUwwVtW4ZzbAmBmD3p52yUQiIhIU63RR3Ctmb1hZkvNLN9LKwS2BeXZ7qWFSxcRkQSJGgjM7BkzWx/iZz7wa2AwMAbYBfy0drUQm3IR0kPtd5GZrTGzNR9//HFMlYlGw0dFRJqK2jTknJsRy4bM7H7gce/tdqAoaPEAYKf3Olx64/0uAZYAlJeXt0rzvm4oFhFpKt5RQ/2C3p4HrPdeLwcWmFm2mZUCQ4FVwGpgqJmVmlkWgQ7l5fGUQURE4hNvZ/E9ZjaGQPPOVuAKAOfcBjNbRqAT+BhwjXPuOICZXQs8BaQDS51zG+IsQ8w055yISFNxBQLn3CURlt0J3BkifQWwIp79iohI69GdxSIiPqdAICLicwoEIiI+56tAoOGjIiJN+SsQ6JYyEZEmfBUInKadExFpwleBQEREmlIgEBHxuZRWOITnAAAHEUlEQVQOBJXVxxNdBBGRDi+lA8HxGvUJiIhEk9KBoPFwUY0aEhFpKqUDgYiIRJfSgUCzjYqIRJfagSDRBRARSQIpHQhqdEkgIhJVSgcCxQERkehSOhCobUhEJLqUDgRqGhIRiS6lA0HjMKBpqEVEmkrpQND4imBgj84JKomISMeV0oEgNzujwftzT+qfoJKIiHRcKR0IcjLTG7w3tQ2JiDSR0oFARESiizsQmNl1ZrbJzDaY2T1B6beaWYW3bHZQ+hwvrcLMbol3/yIiEp+M6FnCM7OpwHxgtHPuqJn19tLLgAXACUB/4BkzG+atdh8wE9gOrDaz5c65t+Iph4iItFxcgQC4CrjbOXcUwDm320ufDzzopb9nZhXABG9ZhXNuC4CZPejlVSAQEUmQeJuGhgFnmtkrZvaCmY330guBbUH5tntp4dLb3NM3TGqP3YiIJJ2oVwRm9gzQN8Sixd76+cApwHhgmZkNgpBPgHGEDjwhb/81s0XAIoDi4uJoxRQRkRaKGgicczPCLTOzq4CHnXMOWGVmNUBPAmf6RUFZBwA7vdfh0hvvdwmwBKC8vDzuuSI024SISGjxNg09CkwD8DqDs4BPgOXAAjPLNrNSYCiwClgNDDWzUjPLItChvDzOMsTEaQY6EZGQ4u0sXgosNbP1QBWw0Ls62GBmywh0Ah8DrnHOHQcws2uBp4B0YKlzbkOcZYiJnlcsIhJaXIHAOVcFfDnMsjuBO0OkrwBWxLPf5vjG9KGs2/Ypw/rkttcuRUSSSrxXBB3ejTOHRc8kIuJjmmJCRMTnFAhERHxOgUBExOcUCEREfE6BQETE5xQIRER8ToFARMTnFAhERHzO
XBLMxmZmHwPvx7GJngTmQPITv9XZb/UF1dkv4qnzQOdcr2iZkiIQxMvM1jjnyhNdjvbktzr7rb6gOvtFe9RZTUMiIj6nQCAi4nN+CQRLEl2ABPBbnf1WX1Cd/aLN6+yLPgIREQnPL1cEIiISRkoHAjObY2abzKzCzG5JdHniYWZLzWy39zS42rQCM1tpZpu93/leupnZvV693zCzcUHrLPTybzazhYmoS6zMrMjMnjOzt81sg5ld76WnbL3NLMfMVpnZ616df+ill5rZK175/+w96hXvcbB/9ur8ipmVBG3rVi99k5nNTkyNYmNm6Wb2mpk97r1P9fpuNbM3zWydma3x0hL3vXbOpeQPgUdhvgsMIvAs5deBskSXK476TALGAeuD0u4BbvFe3wL8i/d6HvAEYMApwCteegGwxfud773OT3TdItS5HzDOe50HvAOUpXK9vbLneq8zgVe8uiwDFnjpvwGu8l5fDfzGe70A+LP3usz7zmcDpd7/Qnqi6xeh3jcCDwCPe+9Tvb5bgZ6N0hL2vU7lK4IJQIVzbosLPFLzQWB+gsvUYs65F4G9jZLnA//hvf4P4HNB6b93AS8D3c2sHzAbWOmc2+uc2wesBOa0felbxjm3yzn3qvf6IPA2UEgK19sr+2fe20zvxwHTgP/y0hvXufaz+C9gupmZl/6gc+6oc+49oILA/0SHY2YDgLOA33nvjRSubwQJ+16nciAoBLYFvd/upaWSPs65XRA4aAK9vfRwdU/az8RrAhhL4Aw5pevtNZOsA3YT+Od+F/jUOXfMyxJc/rq6ecv3Az1Irjr/AvgnoMZ734PUri8EgvvTZrbWzBZ5aQn7XqfyM4stRJpfhkiFq3tSfiZmlgv8Bfimc+5A4AQwdNYQaUlXb+fccWCMmXUHHgFGhsrm/U7qOpvZ2cBu59xaM5tSmxwia0rUN8jpzrmdZtYbWGlmGyPkbfM6p/IVwXagKOj9AGBngsrSVj7yLhHxfu/20sPVPek+EzPLJBAE/uice9hLTvl6AzjnPgWeJ9Au3N3Mak/cgstfVzdveTcCTYjJUufTgXPNbCuB5ttpBK4QUrW+ADjndnq/dxMI9hNI4Pc6lQPBamCoN/ogi0DH0vIEl6m1LQdqRwosBB4LSr/UG21wCrDfu9R8CphlZvneiIRZXlqH5LX9/jvwtnPuZ0GLUrbeZtbLuxLAzDoBMwj0jTwHfMHL1rjOtZ/FF4BnXaAncTmwwBtlUwoMBVa1Ty1i55y71Tk3wDlXQuB/9Fnn3JdI0foCmFkXM8urfU3g+7ieRH6vE9173pY/BHrb3yHQxro40eWJsy5/AnYB1QTOBL5KoG30b8Bm73eBl9eA+7x6vwmUB23ncgIdaRXAZYmuV5Q6n0HgUvcNYJ33My+V6w2MBl7z6rwe+L6XPojAga0CeAjI9tJzvPcV3vJBQdta7H0Wm4C5ia5bDHWfQv2ooZStr1e3172fDbXHpkR+r3VnsYiIz6Vy05CIiMRAgUBExOcUCEREfE6BQETE5xQIRER8ToFARMTnFAhERHxOgUBExOf+D6Uw/cN1r9x1AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "agent = DoubleQLearningAgent(env)\n",
    "\n",
    "# Training: run episodes with learning enabled and record each episode's total reward.\n",
    "episodes = 5000\n",
    "episode_rewards = [play_qlearning(env, agent, train=True)\n",
    "        for _ in range(episodes)]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(episode_rewards)\n",
    "ax.set(title='Double Q-learning training', xlabel='episode',\n",
    "        ylabel='episode reward')\n",
    "\n",
    "# Evaluation: disable exploration and run episodes without learning.\n",
    "agent.epsilon = 0.\n",
    "\n",
    "test_episodes = 100\n",
    "episode_rewards = [play_qlearning(env, agent, train=False)\n",
    "        for _ in range(test_episodes)]\n",
    "print('平均回合奖励 = {} / {} = {}'.format(sum(episode_rewards),\n",
    "        len(episode_rewards), np.mean(episode_rewards)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SARSA($\\lambda$) 算法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SARSALambdaAgent(SARSAAgent):\n",
    "    \"\"\"SARSA(lambda) agent: SARSA extended with an eligibility-trace table.\n",
    "\n",
    "    lambd: trace-decay factor; on every step all traces are scaled by\n",
    "        lambd * gamma before the visited pair is bumped.\n",
    "    beta: weight kept on the old trace of the visited (state, action)\n",
    "        pair; beta=1 accumulates traces, beta=0 replaces them with 1.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, env, lambd=0.5, beta=1.,\n",
    "            gamma=0.9, learning_rate=0.1, epsilon=.01):\n",
    "        super().__init__(env, gamma=gamma, learning_rate=learning_rate,\n",
    "                epsilon=epsilon)\n",
    "        self.lambd = lambd\n",
    "        self.beta = beta\n",
    "        # One trace per (state, action) pair, indexed like self.q[state, action].\n",
    "        trace_shape = (env.observation_space.n, env.action_space.n)\n",
    "        self.e = np.zeros(trace_shape)\n",
    "\n",
    "    def learn(self, state, action, reward, next_state, done, next_action):\n",
    "        \"\"\"Apply one SARSA(lambda) update for the transition\n",
    "        (state, action, reward, next_state, next_action).\"\"\"\n",
    "        # Decay every trace, then bump the trace of the visited pair.\n",
    "        self.e *= self.lambd * self.gamma\n",
    "        previous_trace = self.e[state, action]\n",
    "        self.e[state, action] = 1. + self.beta * previous_trace\n",
    "\n",
    "        # TD target bootstraps from Q(next_state, next_action); the\n",
    "        # (1. - done) factor zeroes the bootstrap term on terminal steps.\n",
    "        target = reward + self.gamma * self.q[next_state, next_action] * (1. - done)\n",
    "        td_error = target - self.q[state, action]\n",
    "        # Every pair moves toward the target in proportion to its trace.\n",
    "        self.q += self.learning_rate * self.e * td_error\n",
    "\n",
    "        # Traces must not leak into the next episode.\n",
    "        if done:\n",
    "            self.e *= 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "平均回合奖励 = 843 / 100 = 8.43\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYkAAAD8CAYAAACCRVh7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAHoVJREFUeJzt3Xl8VPW9//HXJxthJxB2CAQMIiCLREBQQHakV7SLxWrd2mKttlXbWij1tmr5ldpW2/5uN1r53Z+/tle91bY8WhTBrd7rVQRXqAIRKYusIvsa8vn9MSfDJJmThCSTmWTez8cjj5zzPd+Z+X4nk3nPOd9zvmPujoiISDwZyW6AiIikLoWEiIiEUkiIiEgohYSIiIRSSIiISCiFhIiIhFJIiIhIKIWEiIiEUkiIiEiorGQ3oL7y8/O9b9++yW6GiEiTsmbNmr3u3rmmek0+JPr27cvq1auT3QwRkSbFzP5Zm3o63CQiIqEUEiIiEkohISIioRQSIiISSiEhIiKhFBIiIhJKISEiIqEUEiE27TnMSyV7K5SVlVX9qtfNe4/wzw+PcPP/W82bW/ez9/AJdh88zoZdh1j1/j5KT5dRVua4O//88AhHTpTi7tH7+mD/MfYfPclf3tjOK5s+rHL/pafL4j6+u0d/ysvdnXd3HuRE6WkeW72VsjJn676jvLBhT4V6ZWXOydIyNuw6FF0v/yn/OtvY+/+vjXt5b8/hCo8d+7vy83K6zDld5ry/9wjHTp5m96Hj7D50nHhflevu7DoY2V56ugz3yG2PnzrNhl2HKrQt9jHDVN4e27bNe4/w4sY90fJtHx3loyMn2X3wOLsPHo+2Pbb/le/r4PFTvLfncPS5C1P5b3Xs5OkK2/cfPcn2/cd4Z8dByoLny92jr5f39hzm+KnTvLPjYIX2uDunYl4TOw8c58m3d1Roa7y2l7ep8nMc+3ctrxPvdV5Z+fO0ff8xtu47yuNrtuHu0f7sPHCcfUdOsufQCXYfOh59bb225SPWbj/AweOn2LrvaPT+NgevlefW7+bpdTtxj7SvvE3l/Sn/f4h9PVfnqbU72XHgWHT9ZGnk+T128jSbgr5v+fAoh0+UVnkeYv8eYc/P2u0HeH3LR1Xa5O6s3X6AtdsPRNu+fuehCq+v8seo6W+3+9BxXnpvb4W/e2OyVPuOazObAfwUyAR+6+6LqqtfXFzsibiYru+8v8Ut//q0Afzo6Q0N/ngiDe3aMQX87uUtyW5GszHr/O787e0dANwwti///tLmRm9Di6wMTpSeCYt190yndYu6XRNtZmvcvbimeim1J2FmmcDPgZnAIOBqMxuU3FZVpICQpkIB0bDKAwJISkAAFQICICvTEv6YKRUSwCigxN03uftJ4BFgdpLbJCKSklpkZSb8MVItJHoCW2PWtwVlFZjZXDNbbWar9+zZ02iNa4p+c12Ne5NRkwd24ariXhXKJp5bcf6v//jCmArry28fz7Be7avcV7/81jxw1TC+POmcKts+cUEvZp3fvULZO/fOiC7/+rMjq9xmZJ+8mjsQeOVbk2tdF+CBq4bx+YsLuXpUAQ/fNIrBPdpx+5QiNi+axdN3jKegY6u4t7thbF9evOvS0Pvt2q5FdHndPdMBaJebxQNXDeP1u6dWqPvlSefwxr9O5eX5k5k+uGu17X38lrH897xJ/OXWcVW2bV40i82LZkXXf3HNBdHl6y7qw08+PRyIPJ+P3zI2uu3qUQVV7utzFxfy8E2j2LxoFktuiLyOZgzuxuLPjuQHnzg/Wu/HnxoWXX74plGRQ7KfGsaKO8bz7n0zqtxvuStH9OR7Vwzhp3OGV6g3fsCZ19xTt1/Cf8+bFF3vldcy2s/KYvsK8JXJRfTr3LpC2YzB3QB47e6prLtneoX7BvjVtSNZccd4Vt45IbTdsX+fF++6lPf+12Vx68WWL7xyCG9/dxr3f2Jo6P2u
vHMCy28fX+X+fnb1CO67Ykjo7RIt1Sb4i7fvVGXQxN0XA4shMiaR6EY1Bfd/Yih3Pf5WhbLbLj2HqYO6snnRLB5ZtYV5T7xdYfuPPjWM7y97hw+PnOSleZPo0rYFWZkZFPftyIsb93LX9HPpldeSwvnLorc5v1IgdGnbgoduuJDi762ke/tcdhyIDACvvHMCGRlGWZnzv58tAeCHnxzKpr1HuGVifw4cPRXdfX/s5otomZPJtEFdeX7DHqYH/8g3jutLt3a5nNutLRPP7VJhnOiTI3vxxzXbouurFkxm1MJnAOjaLpcnv3oJpaedJ9fuYO/hEzzx2naG9GzPG1v3V3nuLju/Ox+/4Ew4xr5JDejalhe+MZHC+cv4/MWFdO/QkkuK8ln5zi5umdAfM2PmkG7sOnicY6fK+PW1Ixn/w+f4/MWFfGPGuZz77af47Jg+tG6RxWt3TyUr02iXmw3AH74wmn97toSX3vuQ9i2z6dAqB4BFHx9Kx9bryWuVTf/ObTheepp++W04erKUoi5tKegUCa2eHVqyedEs3t97hMfXbOOmiwvj9m3pbeP48dMbmD/zPFrmZNKpTQ6De7SnY+scfve50Xx45ASzh/ekY+tsyhxyMjP46TMbuXZMHwrzI2+ykwZ25defHcmEAZ3Jzc7E3enUugW9O7aqEIYdW+dw26SiKu0AmDu+H4v/vgmAj1/QkweuGl5h++O3jOW5d3fztWkDoq+5gd3aAfCxod25+Jx8rhjRk4PHTgHQMjuTY8Hg9sM3jWL8gM48+Olh3PHomwDcOXUAd04dwL4jJ/nMb17ml9eOjPanXOsWWdw8oR+/fiHSrhlDIq+9A8FjlFtyQzE/eHI937l8EGP757Npz2GOnDhN7+ADxOt3T+XJtTv51p/O/I9lZhhLbxvH29sPcM3oPgBcdWFv7nr8LbIzjX8Z1oP+ndvww+XrATinS5vobb93xRC+/ee1LLjsPC4f1oNdB49z95/XAvDEl8by+pb9bNhZ/ckTDSWlBq7N7CLgu+4+PVifD+Du3w+7TWMPXDeWS4ryeXHjmbOr/s8NF3Ljv78aXX/2axOY9OMXgDOfqq761f+wavM+IPIp9pVvTYnWLw+JLm1bsPvQCVbcMZ6irm05ePwUJ06V0bntmX/0yjbvPcLEHz0PwMaFMyla8GR02/vfvwyzM9n+5tb9vL7lI24YV/UNq7ITpafZc+gEvfLif1KvrPxvkt+mBau/PYVNew7Tr3ObKtvjfcos988Pj9ArrxWZGYk/llsbR06U8uCKDXx9+rnkZif+0EFtlJU5W/YdpW+lN9TqzPjJ33l35yH+9pWLGdyj4geJ17d8RLuWkcCrrZLdhzl8opThvTtUW2/comfZvv8YL951Kb07tuKlkr185revMKZfRx6Ze1GtHy+eZ9/dxYjeeeS1zqn1bQ4cPcWwe58mM8NC9zAq27jrELnZmdHAiWfHgWNc9P1n6d4+l/+Zf3Z7ymFqO3CdansSrwJFZlYIbAfmAJ9prAc/fKKU3KwMzv/u0431kKFa5VR6wwje0zIMXvhG5B9i8sAuXFjYsUqdi/p14r4rBse934nndub+T545PNAuNxtyq29L7JtFVqU319iAABjWuwPDavjHLtciK7PWAVHxdpGjpP3O4k2nXJ9OtX/jawytW2Tx7Y+l1LkZZGTYWQUEwN0fG8S8J96KGwQjCmp/qLBc7Kfq6pR/yM1IQOhPGlj9Yb942rfK5g+fH13tG35lRV3b1linVXbkrXps//yzblN9pVRIuHupmd0GLCdyCuwSd1/XWI8/5DvLmTG4W3QXNtHO7dqW9SHn2984rpDl63ZF13ODAaovXNIv+gJ86IYLK9zmiuE9WfX+Ph789HC6ta/4zj+8IPLGPeW8s3/hQ+TQyPqdhzAznv/6xOieRTI8/LlRSXtsiW/cOfm8eNekmis2sPJLFsozovxDS4Ylb09x7DkN/0bevlU2z3xtQnRMpjGlVEgA
uPsyYFmNFRPkqXU7G+2xLi7KDw2J0YUduXlCP37/8hYOnyglJ8soWTiz2sMkV4/qzVXFvcjKrHo+wsBu7ShZODPuttoY2z8/+immb35rHr9lbJW9ikR78quXsP/oqbM6bCHNW1n5nkQQChf2zeO6i/pw84T+yWxWQiTrdZ9qZzeljcWfHRl9gQ/s1rbCWRjfnDEQM2P+zPNidruNrMyMKod3YplZtSFQ14CIZ2SfvFofVmoo53Vvx0X9OzXqY0pqWzDrPFrlZJIXDPpnZWZw7+wh9OzQ+J+4myuFRJJMG9yN8nMGriruzZXDI2f65rXK5osT+kXrlZ9WkMS9Z5GUNXt4T/5x7wxysvRWlih6ZpPoyhGRYJg0sEv01Lvff35Mxb2FIEmUESKSDCk3JpFOhvXuUOF0zXinbk4a2JU3tx2oMhAtItIYFBIp7suTzuGaMQXktwm/jkFEJFHS/nBT+QVdqSojwxQQIpI0aR8SX/rda1y4cCUfHk7doBARSZa0D4ln3t0NwMjvrWyQ+5s1tHuNde6cOqBBHktEJNHSPiQa2r8M7VHt9inndeUrk+NPgCYikmoUEg0sReaNExFpEAqJBlbdFdEiIk2NQqKBxUbEwiuHkN+m4jTDyhARaUoUEg0sNgSuGd2H1d+eyszgampJrNGFHfnUyF41VxSRWtPFdA0s3p7CL68dyVNrd/LF361p/AalkUdvrt+XzIhIVdqTaGBWwyxLOtokIk2JQqKBhY055GRFNlT5xjkRkRSmw02NZOKALtw5dQDXX9Q32U0REak1hUQDCzsFNiPDdBGdiDQ5CokGZsDfv3EpHx7RXFAi0vRpTKIB9OzQkkuKznz5eUGnVowoyEtii0REGoZCogG4e/SrSEVEmhOFRAPoGvOtcbqiWkSaE4VEPf10znB+c11xspshIpIQGriup9nDewLg6HiTiDQ/2pNoYDVdcS0i0pQoJEREJJRCooHo7CYRaY4UEg2kPCR0dpOINCcKCRERCaWQEBGRUAqJBqajTSLSnCQsJMzsh2b2rpm9ZWZ/MrMOMdvmm1mJma03s+kx5TOCshIzm5eotomISO0kck9iBTDE3YcCG4D5AGY2CJgDDAZmAL8ws0wzywR+DswEBgFXB3WbBF1MJyLNUcJCwt2fdvfSYPVloPwb6mcDj7j7CXd/HygBRgU/Je6+yd1PAo8EdZsWHW8SkWakscYkbgKeDJZ7Altjtm0LysLKE+ax1VtrriQiksbqNXeTma0EusXZtMDd/xLUWQCUAr8vv1mc+k78wIp7DMfM5gJzAQoKCs6y1Wfc9ce36nzbynQxnYg0R/UKCXefUt12M7se+Bgw2T36NroN6B1TrRfwQbAcVl75cRcDiwGKi4tT6u1ZczeJSHOSyLObZgDfBC5396Mxm5YCc8yshZkVAkXAKuBVoMjMCs0sh8jg9tJEtU9ERGqWyKnC/w1oAaywyFwVL7v7F919nZk9BvyDyGGoW939NICZ3QYsBzKBJe6+LoHta1AptTsjItJAEhYS7n5ONdsWAgvjlC8DliWqTYlU2Kk1q97fR7uW+ooOEWk+9I5WDy2zM6PL98wezPQhXRnco30SWyQi0rA0LUcdzRranVcWTI6u52ZnMmlg1yS2SESk4Skk6qhT6xza5WYnuxkiIgmlkKgjnegqIulAISEiIqEUEiIiEiptQ8LrOY+G6XtKRSQNpG1IiIhIzdI2JDQhn4hIzdI2JEREpGZpGxL13ZHQkISIpIO0DQkREalZ2oZEXc9uGtC1TQO3REQkdaVtSNTVfbOHADD1PM3TJCLNX9rOAlvXMYlBPdqxedGsBm2LiEiq0p6EiIiEStuQ0HUSIiI1S9uQqCtNxyEi6SRtQ8L1rdQiIjVK25Coq/pODCgi0pSkbUjovV5EpGZpGxIiIlIzhcRZ0sC1iKQThUQt/OXWcbTOyUx2M0REGl3ahsTZjElo+EJE0lXahkRd6ewmEUknaRsSZ3OdhEYh
RCRdpW9I1HGHQAPXIpJO0jYkRESkZmkbEmezI9G+ZXbC2iEiksrSNiTORt/81slugohIUqRtSNT1LCWd3SQi6SThIWFmXzczN7P8YN3M7GdmVmJmb5nZBTF1rzezjcHP9Ylum4iIVC+hX19qZr2BqcCWmOKZQFHwMxr4JTDazDoC3wGKiQwZrDGzpe7+USLaVtf9AZ3dJCLpJNF7Eg8Cd1HxPXk28LBHvAx0MLPuwHRghbvvC4JhBTAjwe0TEZFqJCwkzOxyYLu7v1lpU09ga8z6tqAsrDwhNLQgIlKzeh1uMrOVQLc4mxYA3wKmxbtZnDKvpjze484F5gIUFBTUqq21cem5nXlu/Z5q62jgWkTSSb1Cwt2nxCs3s/OBQuDN4Bh+L+A1MxtFZA+hd0z1XsAHQfnESuXPhzzuYmAxQHFxcR1PU6pdtcuH9ajT3YuINAcJOdzk7m+7exd37+vufYkEwAXuvhNYClwXnOU0Bjjg7juA5cA0M8szszwieyHLE9E+qP3cTT+7ekSFdQ1ci0g6SejZTSGWAZcBJcBR4EYAd99nZvcBrwb17nX3fUloH/3yW7Np75FkPLSISEpplJAI9ibKlx24NaTeEmBJ47Sp4vqowo7ktcoB4L4rhnDNb19pjGaIiKS0ZOxJpCaHe2YPZmiv9ozt3ynZrRERSQnpOy1HnLK2udncMK4wOu7QodWZif3O79UegKwMjUmISPrQnkSIv375Yrq2y42uL76umI27DpObre+6FpH0kbYhUfl6h8pnOw3p2b7CervcbEb2yUt4u0REUknaHm6qzPQlpSIiVaRtSFQZk1BGiIhUkbYh8eLGStNvaLYNEZEq0jYk7ni04ryD14xpuDmgRESai7QNicr6d26T7CaIiKQchYSIiIRSSIiISCiFhIiIhFJIiIhIKIWEiIiEUkiIiEgohYSIiIRSSIiISCiFhIiIhFJIiIhIKIWEiIiEUkiIiEgohYSIiIRSSARc3ychIlKFQkJEREIpJAKmry8VEalCISEiIqEUEgGNSYiIVKWQEBGRUAoJEREJpZAIODreJCJSmUJCRERCKSQChs6BFRGpLKEhYWZfNrP1ZrbOzO6PKZ9vZiXBtukx5TOCshIzm5fItomISM2yEnXHZnYpMBsY6u4nzKxLUD4ImAMMBnoAK81sQHCznwNTgW3Aq2a21N3/kag2xtKYhIhIVQkLCeAWYJG7nwBw991B+WzgkaD8fTMrAUYF20rcfROAmT0S1G2UkBARkaoSebhpAHCJmb1iZi+Y2YVBeU9ga0y9bUFZWHkVZjbXzFab2eo9e/YkoOkiIgL13JMws5VAtzibFgT3nQeMAS4EHjOzfhB3hNiJH1hxjwG5+2JgMUBxcbGOE4mIJEi9QsLdp4RtM7NbgCfc3YFVZlYG5BPZQ+gdU7UX8EGwHFYuIiJJkMjDTX8GJgEEA9M5wF5gKTDHzFqYWSFQBKwCXgWKzKzQzHKIDG4vTWD7RESkBokcuF4CLDGztcBJ4Ppgr2KdmT1GZEC6FLjV3U8DmNltwHIgE1ji7usS2D4REalBwkLC3U8C14ZsWwgsjFO+DFiWqDaJiMjZ0RXXAU0VLiJSlUJCRERCKSQC+vpSEZGqFBIiIhJKIRHQmISISFUKCRERCaWQEBGRUAoJEREJpZAI9OnUKtlNEBFJOQqJQIdWOclugohIylFIiIhIKIWEiIiEUkiIiEgohYSIiIRSSIiISCiFhIiIhFJIiIhIKIWEiIiEUkiIiEgohYSIiIRSSIiISCiFhIiIhFJIiIhIKIWEiIiEUkiIiEgohYSIiIRSSIiISCiFhIiIhFJIiIhIKIWEiIiEUkiIiEiohIWEmQ03s5fN7A0zW21mo4JyM7OfmVmJmb1lZhfE3OZ6M9sY/FyfqLaJiEjtZCXwvu8H7nH3J83ssmB9IjATKAp+RgO/BEabWUfgO0Ax4MAaM1vq7h8lsI0iIlKNRB5ucqBdsNwe+CBYng08
7BEvAx3MrDswHVjh7vuCYFgBzEhg+0REpAaJ3JO4HVhuZj8iEkZjg/KewNaYetuCsrByERFJknqFhJmtBLrF2bQAmAzc4e6Pm9lVwEPAFMDi1PdqyuM97lxgLkBBQUEdWi4iIrVRr5Bw9ylh28zsYeCrwep/Ar8NlrcBvWOq9iJyKGobkTGL2PLnQx53MbAYoLi4OG6QiIhI/SVyTOIDYEKwPAnYGCwvBa4LznIaAxxw9x3AcmCameWZWR4wLSgTEZEkSeSYxBeAn5pZFnCc4PAQsAy4DCgBjgI3Arj7PjO7D3g1qHevu+9LYPtERKQGCQsJd/8vYGSccgduDbnNEmBJotokIiJnR1dci4hIKIWEiIiEUkiIiEgohYSIiIRSSIiISCiFBPDM1ybUXElEJA0l8jqJJmHzolnJboKISMrSnoSIiIRSSIiISCiFhIiIhFJIiIhIKIWEiIiEUkiIiEgohYSIiIRSSIiISCiFhIiIhFJIiIhIKIWEiIiEUkiIiEgohYSIiIRSSIiISCiFhIiIhFJIiIhIKIWEiIiEUkiIiEgohYSIiIRSSIiISKisZDcgWf70pbG8s+NQspshIpLS0jYkRhTkMaIgL9nNEBFJaTrcJCIioRQSIiISSiEhIiKh6hUSZvYpM1tnZmVmVlxp23wzKzGz9WY2PaZ8RlBWYmbzYsoLzewVM9toZo+aWU592iYiIvVX3z2JtcDHgb/HFprZIGAOMBiYAfzCzDLNLBP4OTATGARcHdQF+AHwoLsXAR8Bn6tn20REpJ7qFRLu/o67r4+zaTbwiLufcPf3gRJgVPBT4u6b3P0k8Agw28wMmAT8Mbj9/wWuqE/bRESk/hI1JtET2Bqzvi0oCyvvBOx399JK5SIikkQ1XidhZiuBbnE2LXD3v4TdLE6ZEz+UvJr6YW2aC8wFKCgoCKsmIiL1VGNIuPuUOtzvNqB3zHov4INgOV75XqCDmWUFexOx9eO1aTGwGMDM9pjZP+vQRoD84LHTifqcHtKtz+nWX6h/n/vUplKirrheCvzBzB4AegBFwCoiewxFZlYIbCcyuP0Zd3czew74JJFxiuuBsL2UCty9c10baWar3b245prNh/qcHtKtz+nWX2i8Ptf3FNgrzWwbcBHwNzNbDuDu64DHgH8ATwG3uvvpYC/hNmA58A7wWFAX4JvAnWZWQmSM4qH6tE1EROqvXnsS7v4n4E8h2xYCC+OULwOWxSnfROTsJxERSRHpfsX14mQ3IAnU5/SQbn1Ot/5CI/XZ3ENPIhIRkTSX7nsSIiJSjbQMibD5o5oiM1tiZrvNbG1MWUczWxHMg7XCzPKCcjOznwX9fsvMLoi5zfVB/Y1mdn0y+lJbZtbbzJ4zs3eCucO+GpQ3236bWa6ZrTKzN4M+3xOUx53zzMxaBOslwfa+MfcVd161VBRM5/O6mf01WG/W/QUws81m9raZvWFmq4Oy5L223T2tfoBM4D2gH5ADvAkMSna76tGf8cAFwNqYsvuBecHyPOAHwfJlwJNETkUeA7wSlHcENgW/84LlvGT3rZo+dwcuCJbbAhuIzAXWbPsdtL1NsJwNvBL05TFgTlD+K+CWYPlLwK+C5TnAo8HyoOA13wIoDP4XMpPdv2r6fSfwB+CvwXqz7m/Q5s1AfqWypL2203FPIu78UUluU525+9+BfZWKZxOZ/woqzoM1G3jYI14mcgFjd2A6sMLd97n7R8AKIhMzpiR33+HurwXLh4icTt2TZtzvoO2Hg9Xs4McJn/Ms9rn4IzA5mCMtbF61lGNmvYBZwG+D9ermeGvy/a1B0l7b6RgSYfNHNSdd3X0HRN5QgS5B+dnOqZXygsMKI4h8sm7W/Q4OvbwB7CbyT/8e4XOeRfsWbD9A5PqjptTnnwB3AWXBenVzvDWH/pZz4GkzW2ORKYggia/tdPyO67OaJ6qZCet7k3xOzKwN8Dhwu7sfjHxwjF81TlmT67e7nwaGm1kHItcnnRevWvC7
SffZzD4G7Hb3NWY2sbw4TtVm0d9Kxrn7B2bWBVhhZu9WUzfh/U7HPYnq5pVqLnYFu5wEv3cH5WF9b3LPiZllEwmI37v7E0Fxs+83gLvvB54ncgy6g5mVf9iLbX+0b8H29kQOSzaVPo8DLjezzUQOCU8ismfRXPsb5e4fBL93E/kwMIokvrbTMSReJZg/KjgzYg6Ruaaak6VE5r+CivNgLQWuC86IGAMcCHZdlwPTzCwvOGtiWlCWkoJjzQ8B77j7AzGbmm2/zaxzsAeBmbUEphAZiymf8wyq9rn8ufgk8KxHRjSXAnOCs4EKOTOvWkpx9/nu3svd+xL5H33W3a+hmfa3nJm1NrO25ctEXpNrSeZrO9kj+cn4IXJGwAYix3QXJLs99ezLfwA7gFNEPj18jsix2GeAjcHvjkFdI/LNgO8BbwPFMfdzE5FBvRLgxmT3q4Y+X0xk1/kt4I3g57Lm3G9gKPB60Oe1wL8G5f2IvOmVAP8JtAjKc4P1kmB7v5j7WhA8F+uBmcnuWy36PpEzZzc16/4G/Xsz+FlX/v6UzNe2rrgWEZFQ6Xi4SUREakkhISIioRQSIiISSiEhIiKhFBIiIhJKISEiIqEUEiIiEkohISIiof4/vYnY0sftPZQAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "agent = SARSALambdaAgent(env)\n",
    "\n",
    "# Training: run episodes with learning enabled and record each episode's total reward.\n",
    "episodes = 5000\n",
    "episode_rewards = [play_sarsa(env, agent, train=True)\n",
    "        for _ in range(episodes)]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(episode_rewards)\n",
    "ax.set(title='SARSA(lambda) training', xlabel='episode',\n",
    "        ylabel='episode reward')\n",
    "\n",
    "# Evaluation: disable exploration and run episodes without learning.\n",
    "agent.epsilon = 0.\n",
    "\n",
    "test_episodes = 100\n",
    "episode_rewards = [play_sarsa(env, agent, train=False)\n",
    "        for _ in range(test_episodes)]\n",
    "print('平均回合奖励 = {} / {} = {}'.format(sum(episode_rewards),\n",
    "        len(episode_rewards), np.mean(episode_rewards)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
